From: <fwi...@us...> - 2006-05-26 02:15:51
|
Revision: 2760 Author: fwierzbicki Date: 2006-05-25 19:15:40 -0700 (Thu, 25 May 2006) ViewCVS: http://svn.sourceforge.net/jython/?rev=2760&view=rev Log Message: ----------- Start of 2.2 level pickle support. Modified Paths: -------------- trunk/jython/Lib/test/test_datetime.py trunk/jython/Lib/test/test_subclasses.py trunk/jython/src/org/python/core/PyInteger.java trunk/jython/src/org/python/core/PyIntegerDerived.java Added Paths: ----------- trunk/jython/Lib/pickle.py Added: trunk/jython/Lib/pickle.py =================================================================== --- trunk/jython/Lib/pickle.py (rev 0) +++ trunk/jython/Lib/pickle.py 2006-05-26 02:15:40 UTC (rev 2760) @@ -0,0 +1,990 @@ +"""Create portable serialized representations of Python objects. + +See module cPickle for a (much) faster implementation. +See module copy_reg for a mechanism for registering custom picklers. + +Classes: + + Pickler + Unpickler + +Functions: + + dump(object, file) + dumps(object) -> string + load(file) -> object + loads(string) -> object + +Misc variables: + + __version__ + format_version + compatible_formats + +""" + +__version__ = "$Revision: 1.56.4.4 $" # Code version + +from types import * +from copy_reg import dispatch_table, safe_constructors +import marshal +import sys +import struct +import re + +__all__ = ["PickleError", "PicklingError", "UnpicklingError", "Pickler", + "Unpickler", "dump", "dumps", "load", "loads"] + +format_version = "1.3" # File format version we write +compatible_formats = ["1.0", "1.1", "1.2"] # Old format versions we can read + +mdumps = marshal.dumps +mloads = marshal.loads + +class PickleError(Exception): pass +class PicklingError(PickleError): pass +class UnpicklingError(PickleError): pass + +class _Stop(Exception): + def __init__(self, value): + self.value = value + +try: + from org.python.core import PyStringMap +except ImportError: + PyStringMap = None + +try: + UnicodeType +except NameError: + UnicodeType = None + + +MARK = '(' +STOP = '.' 
+POP = '0' +POP_MARK = '1' +DUP = '2' +FLOAT = 'F' +INT = 'I' +BININT = 'J' +BININT1 = 'K' +LONG = 'L' +BININT2 = 'M' +NONE = 'N' +PERSID = 'P' +BINPERSID = 'Q' +REDUCE = 'R' +STRING = 'S' +BINSTRING = 'T' +SHORT_BINSTRING = 'U' +UNICODE = 'V' +BINUNICODE = 'X' +APPEND = 'a' +BUILD = 'b' +GLOBAL = 'c' +DICT = 'd' +EMPTY_DICT = '}' +APPENDS = 'e' +GET = 'g' +BINGET = 'h' +INST = 'i' +LONG_BINGET = 'j' +LIST = 'l' +EMPTY_LIST = ']' +OBJ = 'o' +PUT = 'p' +BINPUT = 'q' +LONG_BINPUT = 'r' +SETITEM = 's' +TUPLE = 't' +EMPTY_TUPLE = ')' +SETITEMS = 'u' +BINFLOAT = 'G' + +__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$",x)]) + +class Pickler: + + def __init__(self, file, bin = 0): + self.write = file.write + self.memo = {} + self.bin = bin + + def dump(self, object): + self.save(object) + self.write(STOP) + + def put(self, i): + if self.bin: + s = mdumps(i)[1:] + if i < 256: + return BINPUT + s[0] + + return LONG_BINPUT + s + + return PUT + `i` + '\n' + + def get(self, i): + if self.bin: + s = mdumps(i)[1:] + + if i < 256: + return BINGET + s[0] + + return LONG_BINGET + s + + return GET + `i` + '\n' + + def save(self, object, pers_save = 0): + memo = self.memo + + if not pers_save: + pid = self.persistent_id(object) + if pid is not None: + self.save_pers(pid) + return + + d = id(object) + + t = type(object) + + if (t is TupleType) and (len(object) == 0): + if self.bin: + self.save_empty_tuple(object) + else: + self.save_tuple(object) + return + + if memo.has_key(d): + self.write(self.get(memo[d][0])) + return + + try: + f = self.dispatch[t] + except KeyError: + pid = self.inst_persistent_id(object) + if pid is not None: + self.save_pers(pid) + return + + try: + # XXX: TypeType comparison broken in Jython so Kludging around this for now. 
+ #issc = issubclass(t, TypeType) + issc = str(t) == "<type 'type'>" + except TypeError: # t is not a class + issc = 0 + if issc: + self.save_global(object) + return + + try: + reduce = dispatch_table[t] + except KeyError: + try: + reduce = object.__reduce__ + except AttributeError: + raise PicklingError, \ + "can't pickle %s object: %s" % (`t.__name__`, + `object`) + else: + tup = reduce() + else: + tup = reduce(object) + + if type(tup) is StringType: + self.save_global(object, tup) + return + + if type(tup) is not TupleType: + raise PicklingError, "Value returned by %s must be a " \ + "tuple" % reduce + + l = len(tup) + + if (l != 2) and (l != 3): + raise PicklingError, "tuple returned by %s must contain " \ + "only two or three elements" % reduce + + callable = tup[0] + arg_tup = tup[1] + + if l > 2: + state = tup[2] + else: + state = None + + if type(arg_tup) is not TupleType and arg_tup is not None: + raise PicklingError, "Second element of tuple returned " \ + "by %s must be a tuple" % reduce + + self.save_reduce(callable, arg_tup, state) + memo_len = len(memo) + self.write(self.put(memo_len)) + memo[d] = (memo_len, object) + return + + f(self, object) + + def persistent_id(self, object): + return None + + def inst_persistent_id(self, object): + return None + + def save_pers(self, pid): + if not self.bin: + self.write(PERSID + str(pid) + '\n') + else: + self.save(pid, 1) + self.write(BINPERSID) + + def save_reduce(self, callable, arg_tup, state = None): + write = self.write + save = self.save + + save(callable) + save(arg_tup) + write(REDUCE) + + if state is not None: + save(state) + write(BUILD) + + dispatch = {} + + def save_none(self, object): + self.write(NONE) + dispatch[NoneType] = save_none + + def save_int(self, object): + if self.bin: + # If the int is small enough to fit in a signed 4-byte 2's-comp + # format, we can store it more efficiently than the general + # case. 
+ high_bits = object >> 31 # note that Python shift sign-extends + if high_bits == 0 or high_bits == -1: + # All high bits are copies of bit 2**31, so the value + # fits in a 4-byte signed int. + i = mdumps(object)[1:] + assert len(i) == 4 + if i[-2:] == '\000\000': # fits in 2-byte unsigned int + if i[-3] == '\000': # fits in 1-byte unsigned int + self.write(BININT1 + i[0]) + else: + self.write(BININT2 + i[:2]) + else: + self.write(BININT + i) + return + # Text pickle, or int too big to fit in signed 4-byte format. + self.write(INT + `object` + '\n') + dispatch[IntType] = save_int + + def save_long(self, object): + self.write(LONG + `object` + '\n') + dispatch[LongType] = save_long + + def save_float(self, object, pack=struct.pack): + if self.bin: + self.write(BINFLOAT + pack('>d', object)) + else: + self.write(FLOAT + `object` + '\n') + dispatch[FloatType] = save_float + + def save_string(self, object): + d = id(object) + memo = self.memo + + if self.bin: + l = len(object) + s = mdumps(l)[1:] + if l < 256: + self.write(SHORT_BINSTRING + s[0] + object) + else: + self.write(BINSTRING + s + object) + else: + self.write(STRING + `object` + '\n') + + memo_len = len(memo) + self.write(self.put(memo_len)) + memo[d] = (memo_len, object) + dispatch[StringType] = save_string + + def save_unicode(self, object): + d = id(object) + memo = self.memo + + if self.bin: + encoding = object.encode('utf-8') + l = len(encoding) + s = mdumps(l)[1:] + self.write(BINUNICODE + s + encoding) + else: + object = object.replace("\\", "\\u005c") + object = object.replace("\n", "\\u000a") + self.write(UNICODE + object.encode('raw-unicode-escape') + '\n') + + memo_len = len(memo) + self.write(self.put(memo_len)) + memo[d] = (memo_len, object) + dispatch[UnicodeType] = save_unicode + + if StringType == UnicodeType: + # This is true for Jython + def save_string(self, object): + d = id(object) + memo = self.memo + unicode = object.isunicode() + + if self.bin: + if unicode: + object = 
object.encode("utf-8") + l = len(object) + s = mdumps(l)[1:] + if l < 256 and not unicode: + self.write(SHORT_BINSTRING + s[0] + object) + else: + if unicode: + self.write(BINUNICODE + s + object) + else: + self.write(BINSTRING + s + object) + else: + if unicode: + object = object.replace("\\", "\\u005c") + object = object.replace("\n", "\\u000a") + object = object.encode('raw-unicode-escape') + self.write(UNICODE + object + '\n') + else: + self.write(STRING + `object` + '\n') + + memo_len = len(memo) + self.write(self.put(memo_len)) + memo[d] = (memo_len, object) + dispatch[StringType] = save_string + + def save_tuple(self, object): + + write = self.write + save = self.save + memo = self.memo + + d = id(object) + + write(MARK) + + for element in object: + save(element) + + if len(object) and memo.has_key(d): + if self.bin: + write(POP_MARK + self.get(memo[d][0])) + return + + write(POP * (len(object) + 1) + self.get(memo[d][0])) + return + + memo_len = len(memo) + self.write(TUPLE + self.put(memo_len)) + memo[d] = (memo_len, object) + dispatch[TupleType] = save_tuple + + def save_empty_tuple(self, object): + self.write(EMPTY_TUPLE) + + def save_list(self, object): + d = id(object) + + write = self.write + save = self.save + memo = self.memo + + if self.bin: + write(EMPTY_LIST) + else: + write(MARK + LIST) + + memo_len = len(memo) + write(self.put(memo_len)) + memo[d] = (memo_len, object) + + using_appends = (self.bin and (len(object) > 1)) + + if using_appends: + write(MARK) + + for element in object: + save(element) + + if not using_appends: + write(APPEND) + + if using_appends: + write(APPENDS) + dispatch[ListType] = save_list + + def save_dict(self, object): + d = id(object) + + write = self.write + save = self.save + memo = self.memo + + if self.bin: + write(EMPTY_DICT) + else: + write(MARK + DICT) + + memo_len = len(memo) + self.write(self.put(memo_len)) + memo[d] = (memo_len, object) + + using_setitems = (self.bin and (len(object) > 1)) + + if 
using_setitems: + write(MARK) + + items = object.items() + for key, value in items: + save(key) + save(value) + + if not using_setitems: + write(SETITEM) + + if using_setitems: + write(SETITEMS) + + dispatch[DictionaryType] = save_dict + if not PyStringMap is None: + dispatch[PyStringMap] = save_dict + + def save_inst(self, object): + d = id(object) + cls = object.__class__ + + memo = self.memo + write = self.write + save = self.save + + if hasattr(object, '__getinitargs__'): + args = object.__getinitargs__() + len(args) # XXX Assert it's a sequence + _keep_alive(args, memo) + else: + args = () + + write(MARK) + + if self.bin: + save(cls) + + for arg in args: + save(arg) + + memo_len = len(memo) + if self.bin: + write(OBJ + self.put(memo_len)) + else: + write(INST + cls.__module__ + '\n' + cls.__name__ + '\n' + + self.put(memo_len)) + + memo[d] = (memo_len, object) + + try: + getstate = object.__getstate__ + except AttributeError: + stuff = object.__dict__ + else: + stuff = getstate() + _keep_alive(stuff, memo) + save(stuff) + write(BUILD) + dispatch[InstanceType] = save_inst + + def save_global(self, object, name = None): + write = self.write + memo = self.memo + + if name is None: + name = object.__name__ + + try: + module = object.__module__ + except AttributeError: + module = whichmodule(object, name) + + try: + __import__(module) + mod = sys.modules[module] + klass = getattr(mod, name) + except (ImportError, KeyError, AttributeError): + raise PicklingError( + "Can't pickle %r: it's not found as %s.%s" % + (object, module, name)) + else: + if klass is not object: + raise PicklingError( + "Can't pickle %r: it's not the same object as %s.%s" % + (object, module, name)) + + memo_len = len(memo) + write(GLOBAL + module + '\n' + name + '\n' + + self.put(memo_len)) + memo[id(object)] = (memo_len, object) + dispatch[ClassType] = save_global + dispatch[FunctionType] = save_global + dispatch[BuiltinFunctionType] = save_global + dispatch[TypeType] = save_global + + +def 
_keep_alive(x, memo): + """Keeps a reference to the object x in the memo. + + Because we remember objects by their id, we have + to assure that possibly temporary objects are kept + alive by referencing them. + We store a reference at the id of the memo, which should + normally not be used unless someone tries to deepcopy + the memo itself... + """ + try: + memo[id(memo)].append(x) + except KeyError: + # aha, this is the first one :-) + memo[id(memo)]=[x] + + +classmap = {} # called classmap for backwards compatibility + +def whichmodule(func, funcname): + """Figure out the module in which a function occurs. + + Search sys.modules for the module. + Cache in classmap. + Return a module name. + If the function cannot be found, return __main__. + """ + if classmap.has_key(func): + return classmap[func] + + for name, module in sys.modules.items(): + if module is None: + continue # skip dummy package entries + if name != '__main__' and \ + hasattr(module, funcname) and \ + getattr(module, funcname) is func: + break + else: + name = '__main__' + classmap[func] = name + return name + + +class Unpickler: + + def __init__(self, file): + self.readline = file.readline + self.read = file.read + self.memo = {} + + def load(self): + self.mark = object() # any new unique object + self.stack = [] + self.append = self.stack.append + read = self.read + dispatch = self.dispatch + try: + while 1: + key = read(1) + dispatch[key](self) + except _Stop, stopinst: + return stopinst.value + + def marker(self): + stack = self.stack + mark = self.mark + k = len(stack)-1 + while stack[k] is not mark: k = k-1 + return k + + dispatch = {} + + def load_eof(self): + raise EOFError + dispatch[''] = load_eof + + def load_persid(self): + pid = self.readline()[:-1] + self.append(self.persistent_load(pid)) + dispatch[PERSID] = load_persid + + def load_binpersid(self): + stack = self.stack + + pid = stack[-1] + del stack[-1] + + self.append(self.persistent_load(pid)) + dispatch[BINPERSID] = 
load_binpersid + + def load_none(self): + self.append(None) + dispatch[NONE] = load_none + + def load_int(self): + data = self.readline() + try: + self.append(int(data)) + except ValueError: + self.append(long(data)) + dispatch[INT] = load_int + + def load_binint(self): + self.append(mloads('i' + self.read(4))) + dispatch[BININT] = load_binint + + def load_binint1(self): + self.append(mloads('i' + self.read(1) + '\000\000\000')) + dispatch[BININT1] = load_binint1 + + def load_binint2(self): + self.append(mloads('i' + self.read(2) + '\000\000')) + dispatch[BININT2] = load_binint2 + + def load_long(self): + self.append(long(self.readline()[:-1], 0)) + dispatch[LONG] = load_long + + def load_float(self): + self.append(float(self.readline()[:-1])) + dispatch[FLOAT] = load_float + + def load_binfloat(self, unpack=struct.unpack): + self.append(unpack('>d', self.read(8))[0]) + dispatch[BINFLOAT] = load_binfloat + + def load_string(self): + rep = self.readline()[:-1] + if not self._is_string_secure(rep): + raise ValueError, "insecure string pickle" + self.append(eval(rep, + {'__builtins__': {}})) # Let's be careful + dispatch[STRING] = load_string + + def _is_string_secure(self, s): + """Return true if s contains a string that is safe to eval + + The definition of secure string is based on the implementation + in cPickle. s is secure as long as it only contains a quoted + string and optional trailing whitespace. 
+ """ + q = s[0] + if q not in ("'", '"'): + return 0 + # find the closing quote + offset = 1 + i = None + while 1: + try: + i = s.index(q, offset) + except ValueError: + # if there is an error the first time, there is no + # close quote + if offset == 1: + return 0 + if s[i-1] != '\\': + break + # check to see if this one is escaped + nslash = 0 + j = i - 1 + while j >= offset and s[j] == '\\': + j = j - 1 + nslash = nslash + 1 + if nslash % 2 == 0: + break + offset = i + 1 + for c in s[i+1:]: + if ord(c) > 32: + return 0 + return 1 + + def load_binstring(self): + len = mloads('i' + self.read(4)) + self.append(self.read(len)) + dispatch[BINSTRING] = load_binstring + + def load_unicode(self): + self.append(unicode(self.readline()[:-1],'raw-unicode-escape')) + dispatch[UNICODE] = load_unicode + + def load_binunicode(self): + len = mloads('i' + self.read(4)) + self.append(unicode(self.read(len),'utf-8')) + dispatch[BINUNICODE] = load_binunicode + + def load_short_binstring(self): + len = mloads('i' + self.read(1) + '\000\000\000') + self.append(self.read(len)) + dispatch[SHORT_BINSTRING] = load_short_binstring + + def load_tuple(self): + k = self.marker() + self.stack[k:] = [tuple(self.stack[k+1:])] + dispatch[TUPLE] = load_tuple + + def load_empty_tuple(self): + self.stack.append(()) + dispatch[EMPTY_TUPLE] = load_empty_tuple + + def load_empty_list(self): + self.stack.append([]) + dispatch[EMPTY_LIST] = load_empty_list + + def load_empty_dictionary(self): + self.stack.append({}) + dispatch[EMPTY_DICT] = load_empty_dictionary + + def load_list(self): + k = self.marker() + self.stack[k:] = [self.stack[k+1:]] + dispatch[LIST] = load_list + + def load_dict(self): + k = self.marker() + d = {} + items = self.stack[k+1:] + for i in range(0, len(items), 2): + key = items[i] + value = items[i+1] + d[key] = value + self.stack[k:] = [d] + dispatch[DICT] = load_dict + + def load_inst(self): + k = self.marker() + args = tuple(self.stack[k+1:]) + del self.stack[k:] + module = 
self.readline()[:-1] + name = self.readline()[:-1] + klass = self.find_class(module, name) + instantiated = 0 + if (not args and type(klass) is ClassType and + not hasattr(klass, "__getinitargs__")): + try: + value = _EmptyClass() + value.__class__ = klass + instantiated = 1 + except RuntimeError: + # In restricted execution, assignment to inst.__class__ is + # prohibited + pass + if not instantiated: + try: + #XXX: This test is deprecated in 2.3, so commenting out. + #if not hasattr(klass, '__safe_for_unpickling__'): + # raise UnpicklingError('%s is not safe for unpickling' % + # klass) + value = apply(klass, args) + except TypeError, err: + raise TypeError, "in constructor for %s: %s" % ( + klass.__name__, str(err)), sys.exc_info()[2] + self.append(value) + dispatch[INST] = load_inst + + def load_obj(self): + stack = self.stack + k = self.marker() + klass = stack[k + 1] + del stack[k + 1] + args = tuple(stack[k + 1:]) + del stack[k:] + instantiated = 0 + if (not args and type(klass) is ClassType and + not hasattr(klass, "__getinitargs__")): + try: + value = _EmptyClass() + value.__class__ = klass + instantiated = 1 + except RuntimeError: + # In restricted execution, assignment to inst.__class__ is + # prohibited + pass + if not instantiated: + value = apply(klass, args) + self.append(value) + dispatch[OBJ] = load_obj + + def load_global(self): + module = self.readline()[:-1] + name = self.readline()[:-1] + klass = self.find_class(module, name) + self.append(klass) + dispatch[GLOBAL] = load_global + + def find_class(self, module, name): + __import__(module) + mod = sys.modules[module] + klass = getattr(mod, name) + return klass + + def load_reduce(self): + stack = self.stack + + callable = stack[-2] + arg_tup = stack[-1] + del stack[-2:] + + #XXX: The __safe_for_unpickling__ test is deprecated in 2.3, so commenting out. 
+ #if type(callable) is not ClassType: + #if not safe_constructors.has_key(callable): + #try: + # safe = callable.__safe_for_unpickling__ + #except AttributeError: + # safe = None + # + #if not safe: + # raise UnpicklingError, "%s is not safe for " \ + # "unpickling" % callable + + if arg_tup is None: + value = callable.__basicnew__() + else: + value = apply(callable, arg_tup) + self.append(value) + dispatch[REDUCE] = load_reduce + + def load_pop(self): + del self.stack[-1] + dispatch[POP] = load_pop + + def load_pop_mark(self): + k = self.marker() + del self.stack[k:] + dispatch[POP_MARK] = load_pop_mark + + def load_dup(self): + self.append(self.stack[-1]) + dispatch[DUP] = load_dup + + def load_get(self): + self.append(self.memo[self.readline()[:-1]]) + dispatch[GET] = load_get + + def load_binget(self): + i = mloads('i' + self.read(1) + '\000\000\000') + self.append(self.memo[`i`]) + dispatch[BINGET] = load_binget + + def load_long_binget(self): + i = mloads('i' + self.read(4)) + self.append(self.memo[`i`]) + dispatch[LONG_BINGET] = load_long_binget + + def load_put(self): + self.memo[self.readline()[:-1]] = self.stack[-1] + dispatch[PUT] = load_put + + def load_binput(self): + i = mloads('i' + self.read(1) + '\000\000\000') + self.memo[`i`] = self.stack[-1] + dispatch[BINPUT] = load_binput + + def load_long_binput(self): + i = mloads('i' + self.read(4)) + self.memo[`i`] = self.stack[-1] + dispatch[LONG_BINPUT] = load_long_binput + + def load_append(self): + stack = self.stack + value = stack[-1] + del stack[-1] + list = stack[-1] + list.append(value) + dispatch[APPEND] = load_append + + def load_appends(self): + stack = self.stack + mark = self.marker() + list = stack[mark - 1] + for i in range(mark + 1, len(stack)): + list.append(stack[i]) + + del stack[mark:] + dispatch[APPENDS] = load_appends + + def load_setitem(self): + stack = self.stack + value = stack[-1] + key = stack[-2] + del stack[-2:] + dict = stack[-1] + dict[key] = value + dispatch[SETITEM] = 
load_setitem + + def load_setitems(self): + stack = self.stack + mark = self.marker() + dict = stack[mark - 1] + for i in range(mark + 1, len(stack), 2): + dict[stack[i]] = stack[i + 1] + + del stack[mark:] + dispatch[SETITEMS] = load_setitems + + def load_build(self): + stack = self.stack + value = stack[-1] + del stack[-1] + inst = stack[-1] + try: + setstate = inst.__setstate__ + except AttributeError: + try: + inst.__dict__.update(value) + except RuntimeError: + # XXX In restricted execution, the instance's __dict__ is not + # accessible. Use the old way of unpickling the instance + # variables. This is a semantic different when unpickling in + # restricted vs. unrestricted modes. + for k, v in value.items(): + setattr(inst, k, v) + else: + setstate(value) + dispatch[BUILD] = load_build + + def load_mark(self): + self.append(self.mark) + dispatch[MARK] = load_mark + + def load_stop(self): + value = self.stack[-1] + del self.stack[-1] + raise _Stop(value) + dispatch[STOP] = load_stop + +# Helper class for load_inst/load_obj + +class _EmptyClass: + pass + +# Shorthands + +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + +def dump(object, file, bin = 0): + Pickler(file, bin).dump(object) + +def dumps(object, bin = 0): + file = StringIO() + Pickler(file, bin).dump(object) + return file.getvalue() + +def load(file): + return Unpickler(file).load() + +def loads(str): + file = StringIO(str) + return Unpickler(file).load() Modified: trunk/jython/Lib/test/test_datetime.py =================================================================== --- trunk/jython/Lib/test/test_datetime.py 2006-05-24 17:16:28 UTC (rev 2759) +++ trunk/jython/Lib/test/test_datetime.py 2006-05-26 02:15:40 UTC (rev 2760) @@ -15,11 +15,12 @@ from datetime import date, datetime # Before Python 2.3, proto=2 was taken as a synonym for proto=1. +# cPickle not updated in Jython so commenting out. 
pickle_choices = [(pickler, unpickler, proto) - for pickler in pickle, cPickle - for unpickler in pickle, cPickle + for pickler in pickle, #cPickle + for unpickler in pickle, #cPickle for proto in range(3)] -assert len(pickle_choices) == 2*2*3 +#assert len(pickle_choices) == 2*2*3 # An arbitrary collection of objects of non-datetime types, for testing # mixed-type comparisons. @@ -96,7 +97,7 @@ self.assertEqual(fo.dst(dt), timedelta(minutes=42)) #XXX: pickling not working for jython yet. - def _test_pickling_base(self): + def test_pickling_base(self): # There's no point to pickling tzinfo objects on their own (they # carry no data), but they need to be picklable anyway else # concrete subclasses can't be pickled. @@ -108,7 +109,7 @@ self.failUnless(type(derived) is tzinfo) #XXX: pickling not working for jython yet. - def _test_pickling_subclass(self): + def test_pickling_subclass(self): # Make sure we can pickle/unpickle an instance of a subclass. offset = timedelta(minutes=-300) orig = PicklableFixedOffset(offset, 'cookie') @@ -298,7 +299,7 @@ self.assertEqual(d[t1], 2) #XXX: pickling not working for jython yet. - def _test_pickling(self): + def test_pickling(self): args = 12, 34, 56 orig = timedelta(*args) for pickler, unpickler, proto in pickle_choices: @@ -852,7 +853,7 @@ self.assertEqual(t.tm_isdst, -1) #XXX: pickling not working for jython yet. - def _test_pickling(self): + def test_pickling(self): args = 6, 7, 23 orig = self.theclass(*args) for pickler, unpickler, proto in pickle_choices: @@ -1205,7 +1206,7 @@ self.assertRaises(TypeError, lambda: a + a) #XXX: pickling not working for jython yet. - def _test_pickling(self): + def test_pickling(self): args = 6, 7, 23, 20, 59, 1, 64**2 orig = self.theclass(*args) for pickler, unpickler, proto in pickle_choices: @@ -1214,7 +1215,7 @@ self.assertEqual(orig, derived) #XXX: pickling not working for jython yet. 
- def _test_more_pickling(self): + def test_more_pickling(self): a = self.theclass(2003, 2, 7, 16, 48, 37, 444116) s = pickle.dumps(a) b = pickle.loads(s) @@ -1592,7 +1593,7 @@ self.assert_(self.theclass.max > self.theclass.min) #XXX: pickling not working for jython yet. - def _test_pickling(self): + def test_pickling(self): args = 20, 59, 16, 64**2 orig = self.theclass(*args) for pickler, unpickler, proto in pickle_choices: @@ -1900,7 +1901,7 @@ self.assertEqual(hash(t1), hash(t2)) #XXX: pickling not working for jython yet. - def _test_pickling(self): + def test_pickling(self): # Try one without a tzinfo. args = 20, 59, 16, 64**2 orig = self.theclass(*args) @@ -2101,7 +2102,7 @@ self.assertRaises(ValueError, lambda: t1 == t2) #XXX: pickling not working for jython yet. - def _test_pickling(self): + def test_pickling(self): # Try one without a tzinfo. args = 6, 7, 23, 20, 59, 1, 64**2 orig = self.theclass(*args) Modified: trunk/jython/Lib/test/test_subclasses.py =================================================================== --- trunk/jython/Lib/test/test_subclasses.py 2006-05-24 17:16:28 UTC (rev 2759) +++ trunk/jython/Lib/test/test_subclasses.py 2006-05-26 02:15:40 UTC (rev 2760) @@ -128,7 +128,12 @@ self.assertEqual(SubSubStrSpam().eggs(), "I am eggs.") self.assertEqual(SubSubStrSpam2().eggs(), "I am eggs.") + def test_pickle_builtins(self): + class myint(int): + def __init__(self, x): + self.str = str(x) + def test_suite(): allsuites = [unittest.makeSuite(klass, 'test') for klass in (TestSubclasses, Modified: trunk/jython/src/org/python/core/PyInteger.java =================================================================== --- trunk/jython/src/org/python/core/PyInteger.java 2006-05-24 17:16:28 UTC (rev 2759) +++ trunk/jython/src/org/python/core/PyInteger.java 2006-05-26 02:15:40 UTC (rev 2760) @@ -1058,6 +1058,34 @@ } dict.__setitem__("__nonzero__",new PyMethodDescr("__nonzero__",PyInteger.class,0,0,new exposed___nonzero__(null,null))); + class 
exposed___reduce__ extends PyBuiltinFunctionNarrow { + + private PyInteger self; + + public PyObject getSelf() { + return self; + } + + exposed___reduce__(PyInteger self,PyBuiltinFunction.Info info) { + super(info); + this.self=self; + } + + public PyBuiltinFunction makeBound(PyObject self) { + return new exposed___reduce__((PyInteger)self,info); + } + + public PyObject __call__() { + return self.int___reduce__(); + } + + public PyObject inst_call(PyObject gself) { + PyInteger self=(PyInteger)gself; + return self.int___reduce__(); + } + + } + dict.__setitem__("__reduce__",new PyMethodDescr("__reduce__",PyInteger.class,0,0,new exposed___reduce__(null,null))); class exposed___repr__ extends PyBuiltinFunctionNarrow { private PyInteger self; @@ -1782,4 +1810,22 @@ public int asInt(int index) throws PyObject.ConversionException { return getValue(); } + + /** + * Used for pickling. + * + * @return a tuple of (class, (Integer)) + */ + public PyObject __reduce__() { + return int___reduce__(); + } + + public PyObject int___reduce__() { + return new PyTuple(new PyObject[]{ + getType(), + new PyTuple(new PyObject[]{ + new PyString(int_toString()) + }) + }); + } } Modified: trunk/jython/src/org/python/core/PyIntegerDerived.java =================================================================== --- trunk/jython/src/org/python/core/PyIntegerDerived.java 2006-05-24 17:16:28 UTC (rev 2759) +++ trunk/jython/src/org/python/core/PyIntegerDerived.java 2006-05-26 02:15:40 UTC (rev 2760) @@ -12,9 +12,8 @@ return dict; } - //XXX: hand modified to pass v into super - public PyIntegerDerived(PyType subtype, int v) { - super(subtype, v); + public PyIntegerDerived(PyType subtype,int v) { + super(subtype,v); dict=subtype.instDict(); } @@ -42,6 +41,18 @@ return super.__repr__(); } + public PyUnicode __unicode__() { + PyType self_type=getType(); + PyObject impl=self_type.lookup("__unicode__"); + if (impl!=null) { + PyObject res=impl.__get__(this,self_type).__call__(); + if (res instanceof 
PyUnicode) + return(PyUnicode)res; + throw Py.TypeError("__unicode__"+" should return a "+"unicode"); + } + return super.__unicode__(); + } + public PyString __hex__() { PyType self_type=getType(); PyObject impl=self_type.lookup("__hex__"); @@ -146,6 +157,19 @@ return super.__invert__(); } + public PyObject __reduce__() { + PyType self_type=getType(); + PyObject impl=self_type.lookup("__reduce__"); + if (impl!=null) + return impl.__get__(this,self_type).__call__(); + return new PyTuple(new PyObject[]{ + getType(), + new PyTuple(new PyObject[]{ + new PyString("x") + }) + }); + } + public PyObject __add__(PyObject other) { PyType self_type=getType(); PyObject impl=self_type.lookup("__add__"); @@ -779,6 +803,35 @@ super.__setitem__(key,value); } + public PyObject __getitem__(PyObject key) { // ??? + PyType self_type=getType(); + PyObject impl=self_type.lookup("__getitem__"); + if (impl!=null) { + try { + return impl.__get__(this,self_type).__call__(key); + } catch (PyException exc) { + if (Py.matchException(exc,Py.LookupError)) + return null; + throw exc; + } + } + return super.__getitem__(key); + } + + public PyObject __getslice__(PyObject start,PyObject stop,PyObject step) { // ??? + PyType self_type=getType(); + PyObject impl=self_type.lookup("__getslice__"); + if (impl!=null) + try { + return impl.__get__(this,self_type).__call__(start,stop); + } catch (PyException exc) { + if (Py.matchException(exc,Py.LookupError)) + return null; + throw exc; + } + return super.__getslice__(start,stop,step); + } + public void __delitem__(PyObject key) { // ??? PyType self_type=getType(); PyObject impl=self_type.lookup("__delitem__"); This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <cg...@us...> - 2006-08-31 03:13:18
|
Revision: 2920 http://svn.sourceforge.net/jython/?rev=2920&view=rev Author: cgroves Date: 2006-08-30 20:13:09 -0700 (Wed, 30 Aug 2006) Log Message: ----------- initial import of core pyxml Added Paths: ----------- trunk/jython/Lib/xml/ trunk/jython/Lib/xml/FtCore.py trunk/jython/Lib/xml/Uri.py trunk/jython/Lib/xml/__init__.py trunk/jython/Lib/xml/dom/ trunk/jython/Lib/xml/dom/MessageSource.py trunk/jython/Lib/xml/dom/NodeFilter.py trunk/jython/Lib/xml/dom/__init__.py trunk/jython/Lib/xml/dom/domreg.py trunk/jython/Lib/xml/dom/minicompat.py trunk/jython/Lib/xml/dom/minidom.py trunk/jython/Lib/xml/dom/pulldom.py trunk/jython/Lib/xml/dom/xmlbuilder.py trunk/jython/Lib/xml/sax/ trunk/jython/Lib/xml/sax/__init__.py trunk/jython/Lib/xml/sax/_exceptions.py trunk/jython/Lib/xml/sax/drivers2/ trunk/jython/Lib/xml/sax/drivers2/__init__.py trunk/jython/Lib/xml/sax/drivers2/drv_javasax.py trunk/jython/Lib/xml/sax/handler.py trunk/jython/Lib/xml/sax/saxlib.py trunk/jython/Lib/xml/sax/saxutils.py trunk/jython/Lib/xml/sax/xmlreader.py Added: trunk/jython/Lib/xml/FtCore.py =================================================================== --- trunk/jython/Lib/xml/FtCore.py (rev 0) +++ trunk/jython/Lib/xml/FtCore.py 2006-08-31 03:13:09 UTC (rev 2920) @@ -0,0 +1,58 @@ +""" +Contains various definitions common to modules acquired from 4Suite +""" + +__all__ = ["FtException", "get_translator"] + + +class FtException(Exception): + def __init__(self, errorCode, messages, args): + # By defining __str__, args will be available. Otherwise + # the __init__ of Exception sets it to the passed in arguments. + self.params = args + self.errorCode = errorCode + self.message = messages[errorCode] % args + Exception.__init__(self, self.message, args) + + def __str__(self): + return self.message + + +# What follows is used to provide support for I18N in the rest of the +# 4Suite-derived packages in PyXML. 
+# +# Each sub-package of the top-level "xml" package that contains 4Suite +# code is really a separate text domain, but they're all called +# '4Suite'. For each domain, a translation object is provided using +# message catalogs stored inside the package. The code below defines +# a get_translator() function that returns an appropriate gettext +# function to be used as _() in the sub-package named by the +# parameter. This handles all the compatibility issues related to +# Python versions (whether the gettext module can be found) and +# whether the message catalogs can actually be found. + +def _(msg): + return msg + +try: + import gettext + +except (ImportError, IOError): + def get_translator(pkg): + return _ + +else: + import os + + _cache = {} + _top = os.path.dirname(os.path.abspath(__file__)) + + def get_translator(pkg): + if not _cache.has_key(pkg): + locale_dir = os.path.join(_top, pkg.replace(".", os.sep)) + try: + f = gettext.translation('4Suite', locale_dir).gettext + except IOError: + f = _ + _cache[pkg] = f + return _cache[pkg] Added: trunk/jython/Lib/xml/Uri.py =================================================================== --- trunk/jython/Lib/xml/Uri.py (rev 0) +++ trunk/jython/Lib/xml/Uri.py 2006-08-31 03:13:09 UTC (rev 2920) @@ -0,0 +1,380 @@ +# pylint: disable-msg=C0103 +# +# backported code from 4Suite with slight modifications, started from r1.89 of +# Ft/Lib/Uri.py, by sy...@lo... on 2005-02-09 +# +# part if not all of this code should probably move to urlparse (or be used +# to fix some existant functions in this module) +# +# +# Copyright 2004 Fourthought, Inc. (USA). 
+# Detailed license and copyright information: http://4suite.org/COPYRIGHT +# Project home, documentation, distributions: http://4suite.org/ +import os.path +import sys +import re +import urlparse, urllib, urllib2 + +def UnsplitUriRef(uriRefSeq): + """should replace urlparse.urlunsplit + + Given a sequence as would be produced by SplitUriRef(), assembles and + returns a URI reference as a string. + """ + if not isinstance(uriRefSeq, (tuple, list)): + raise TypeError("sequence expected, got %s" % type(uriRefSeq)) + (scheme, authority, path, query, fragment) = uriRefSeq + uri = '' + if scheme is not None: + uri += scheme + ':' + if authority is not None: + uri += '//' + authority + uri += path + if query is not None: + uri += '?' + query + if fragment is not None: + uri += '#' + fragment + return uri + +SPLIT_URI_REF_PATTERN = re.compile(r"^(?:(?P<scheme>[^:/?#]+):)?(?://(?P<authority>[^/?#]*))?(?P<path>[^?#]*)(?:\?(?P<query>[^#]*))?(?:#(?P<fragment>.*))?$") + +def SplitUriRef(uriref): + """should replace urlparse.urlsplit + + Given a valid URI reference as a string, returns a tuple representing the + generic URI components, as per RFC 2396 appendix B. The tuple's structure + is (scheme, authority, path, query, fragment). + + All values will be strings (possibly empty) or None if undefined. + + Note that per rfc3986, there is no distinction between a path and + an "opaque part", as there was in RFC 2396. + """ + # the pattern will match every possible string, so it's safe to + # assume there's a groupdict method to call. + g = SPLIT_URI_REF_PATTERN.match(uriref).groupdict() + scheme = g['scheme'] + authority = g['authority'] + path = g['path'] + query = g['query'] + fragment = g['fragment'] + return (scheme, authority, path, query, fragment) + + +def Absolutize(uriRef, baseUri): + """ + Resolves a URI reference to absolute form, effecting the result of RFC + 3986 section 5. The URI reference is considered to be relative to the + given base URI. 
+ + It is the caller's responsibility to ensure that the base URI matches + the absolute-URI syntax rule of RFC 3986, and that its path component + does not contain '.' or '..' segments if the scheme is hierarchical. + Unexpected results may occur otherwise. + + This function only conducts a minimal sanity check in order to determine + if relative resolution is possible: it raises a UriException if the base + URI does not have a scheme component. While it is true that the base URI + is irrelevant if the URI reference has a scheme, an exception is raised + in order to signal that the given string does not even come close to + meeting the criteria to be usable as a base URI. + + It is the caller's responsibility to make a determination of whether the + URI reference constitutes a "same-document reference", as defined in RFC + 2396 or RFC 3986. As per the spec, dereferencing a same-document + reference "should not" involve retrieval of a new representation of the + referenced resource. Note that the two specs have different definitions + of same-document reference: RFC 2396 says it is *only* the cases where the + reference is the empty string, or "#" followed by a fragment; RFC 3986 + requires making a comparison of the base URI to the absolute form of the + reference (as is returned by the spec), minus its fragment component, + if any. + + This function is similar to urlparse.urljoin() and urllib.basejoin(). + Those functions, however, are (as of Python 2.3) outdated, buggy, and/or + designed to produce results acceptable for use with other core Python + libraries, rather than being earnest implementations of the relevant + specs. Their problems are most noticeable in their handling of + same-document references and 'file:' URIs, both being situations that + come up far too often to consider the functions reliable enough for + general use. 
+ """ + # Reasons to avoid using urllib.basejoin() and urlparse.urljoin(): + # - Both are partial implementations of long-obsolete specs. + # - Both accept relative URLs as the base, which no spec allows. + # - urllib.basejoin() mishandles the '' and '..' references. + # - If the base URL uses a non-hierarchical or relative path, + # or if the URL scheme is unrecognized, the result is not + # always as expected (partly due to issues in RFC 1808). + # - If the authority component of a 'file' URI is empty, + # the authority component is removed altogether. If it was + # not present, an empty authority component is in the result. + # - '.' and '..' segments are not always collapsed as well as they + # should be (partly due to issues in RFC 1808). + # - Effective Python 2.4, urllib.basejoin() *is* urlparse.urljoin(), + # but urlparse.urljoin() is still based on RFC 1808. + + # This procedure is based on the pseudocode in RFC 3986 sec. 5.2. + # + # ensure base URI is absolute + if not baseUri: + raise ValueError('baseUri is required and must be a non empty string') + if not IsAbsolute(baseUri): + raise ValueError('%r is not an absolute URI' % baseUri) + # shortcut for the simplest same-document reference cases + if uriRef == '' or uriRef[0] == '#': + return baseUri.split('#')[0] + uriRef + # ensure a clean slate + tScheme = tAuth = tPath = tQuery = None + # parse the reference into its components + (rScheme, rAuth, rPath, rQuery, rFrag) = SplitUriRef(uriRef) + # if the reference is absolute, eliminate '.' and '..' path segments + # and skip to the end + if rScheme is not None: + tScheme = rScheme + tAuth = rAuth + tPath = RemoveDotSegments(rPath) + tQuery = rQuery + else: + # the base URI's scheme, and possibly more, will be inherited + (bScheme, bAuth, bPath, bQuery, bFrag) = SplitUriRef(baseUri) + # if the reference is a net-path, just eliminate '.' and '..' path + # segments; no other changes needed. 
+ if rAuth is not None: + tAuth = rAuth + tPath = RemoveDotSegments(rPath) + tQuery = rQuery + # if it's not a net-path, we need to inherit pieces of the base URI + else: + # use base URI's path if the reference's path is empty + if not rPath: + tPath = bPath + # use the reference's query, if any, or else the base URI's, + tQuery = rQuery is not None and rQuery or bQuery + # the reference's path is not empty + else: + # just use the reference's path if it's absolute + if rPath[0] == '/': + tPath = RemoveDotSegments(rPath) + # merge the reference's relative path with the base URI's path + else: + if bAuth is not None and not bPath: + tPath = '/' + rPath + else: + tPath = bPath[:bPath.rfind('/')+1] + rPath + tPath = RemoveDotSegments(tPath) + # use the reference's query + tQuery = rQuery + # since the reference isn't a net-path, + # use the authority from the base URI + tAuth = bAuth + # inherit the scheme from the base URI + tScheme = bScheme + # always use the reference's fragment (but no need to define another var) + #tFrag = rFrag + + # now compose the target URI (RFC 3986 sec. 5.3) + return UnsplitUriRef((tScheme, tAuth, tPath, tQuery, rFrag)) + + +REG_NAME_HOST_PATTERN = re.compile(r"^(?:(?:[0-9A-Za-z\-_\.!~*'();&=+$,]|(?:%[0-9A-Fa-f]{2}))*)$") + +def MakeUrllibSafe(uriRef): + """ + Makes the given RFC 3986-conformant URI reference safe for passing + to legacy urllib functions. The result may not be a valid URI. + + As of Python 2.3.3, urllib.urlopen() does not fully support + internationalized domain names, it does not strip fragment components, + and on Windows, it expects file URIs to use '|' instead of ':' in the + path component corresponding to the drivespec. It also relies on + urllib.unquote(), which mishandles unicode arguments. This function + produces a URI reference that will work around these issues, although + the IDN workaround is limited to Python 2.3 only. 
May raise a + UnicodeEncodeError if the URI reference is Unicode and erroneously + contains non-ASCII characters. + """ + # IDN support requires decoding any percent-encoded octets in the + # host part (if it's a reg-name) of the authority component, and when + # doing DNS lookups, applying IDNA encoding to that string first. + # As of Python 2.3, there is an IDNA codec, and the socket and httplib + # modules accept Unicode strings and apply IDNA encoding automatically + # where necessary. However, urllib.urlopen() has not yet been updated + # to do the same; it raises an exception if you give it a Unicode + # string, and does no conversion on non-Unicode strings, meaning you + # have to give it an IDNA string yourself. We will only support it on + # Python 2.3 and up. + # + # see if host is a reg-name, as opposed to IPv4 or IPv6 addr. + if isinstance(uriRef, unicode): + try: + uriRef = uriRef.encode('us-ascii') # parts of urllib are not unicode safe + except UnicodeError: + raise ValueError("uri %r must consist of ASCII characters." % uriRef) + (scheme, auth, path, query, frag) = urlparse.urlsplit(uriRef) + if auth and auth.find('@') > -1: + userinfo, hostport = auth.split('@') + else: + userinfo = None + hostport = auth + if hostport and hostport.find(':') > -1: + host, port = hostport.split(':') + else: + host = hostport + port = None + if host and REG_NAME_HOST_PATTERN.match(host): + # percent-encoded hostnames will always fail DNS lookups + host = urllib.unquote(host) #PercentDecode(host) + # IDNA-encode if possible. + # We shouldn't do this for schemes that don't need DNS lookup, + # but are there any (that you'd be calling urlopen for)? 
+ if sys.version_info[0:2] >= (2, 3): + if isinstance(host, str): + host = host.decode('utf-8') + host = host.encode('idna') + # reassemble the authority with the new hostname + # (percent-decoded, and possibly IDNA-encoded) + auth = '' + if userinfo: + auth += userinfo + '@' + auth += host + if port: + auth += ':' + port + + # On Windows, ensure that '|', not ':', is used in a drivespec. + if os.name == 'nt' and scheme == 'file': + path = path.replace(':', '|', 1) + + # Note that we drop fragment, if any. See RFC 3986 sec. 3.5. + uri = urlparse.urlunsplit((scheme, auth, path, query, None)) + + return uri + + + +def BaseJoin(base, uriRef): + """ + Merges a base URI reference with another URI reference, returning a + new URI reference. + + It behaves exactly the same as Absolutize(), except the arguments + are reversed, and it accepts any URI reference (even a relative URI) + as the base URI. If the base has no scheme component, it is + evaluated as if it did, and then the scheme component of the result + is removed from the result, unless the uriRef had a scheme. Thus, if + neither argument has a scheme component, the result won't have one. + + This function is named BaseJoin because it is very much like + urllib.basejoin(), but it follows the current rfc3986 algorithms + for path merging, dot segment elimination, and inheritance of query + and fragment components. + + WARNING: This function exists for 2 reasons: (1) because of a need + within the 4Suite repository to perform URI reference absolutization + using base URIs that are stored (inappropriately) as absolute paths + in the subjects of statements in the RDF model, and (2) because of + a similar need to interpret relative repo paths in a 4Suite product + setup.xml file as being relative to a path that can be set outside + the document. When these needs go away, this function probably will, + too, so it is not advisable to use it. 
+ """ + if IsAbsolute(base): + return Absolutize(uriRef, base) + else: + dummyscheme = 'basejoin' + res = Absolutize(uriRef, '%s:%s' % (dummyscheme, base)) + if IsAbsolute(uriRef): + # scheme will be inherited from uriRef + return res + else: + # no scheme in, no scheme out + return res[len(dummyscheme)+1:] + + +def RemoveDotSegments(path): + """ + Supports Absolutize() by implementing the remove_dot_segments function + described in RFC 3986 sec. 5.2. It collapses most of the '.' and '..' + segments out of a path without eliminating empty segments. It is intended + to be used during the path merging process and may not give expected + results when used independently. Use NormalizePathSegments() or + NormalizePathSegmentsInUri() if more general normalization is desired. + + semi-private because it is not for general use. I've implemented it + using two segment stacks, as alluded to in the spec, rather than the + explicit string-walking algorithm that would be too inefficient. (mbrown) + """ + # return empty string if entire path is just "." or ".." + if path == '.' or path == '..': + return path[0:0] # preserves string type + # remove all "./" or "../" segments at the beginning + while path: + if path[:2] == './': + path = path[2:] + elif path[:3] == '../': + path = path[3:] + else: + break + # We need to keep track of whether there was a leading slash, + # because we're going to drop it in order to prevent our list of + # segments from having an ambiguous empty first item when we call + # split(). + leading_slash = 0 + if path[:1] == '/': + path = path[1:] + leading_slash = 1 + # replace a trailing "/." with just "/" + if path[-2:] == '/.': + path = path[:-1] + # convert the segments into a list and process each segment in + # order from left to right. + segments = path.split('/') + keepers = [] + segments.reverse() + while segments: + seg = segments.pop() + # '..' means drop the previous kept segment, if any. 
+ # If none, and if the path is relative, then keep the '..'. + # If the '..' was the last segment, ensure + # that the result ends with '/'. + if seg == '..': + if keepers: + keepers.pop() + elif not leading_slash: + keepers.append(seg) + if not segments: + keepers.append('') + # ignore '.' segments and keep all others, even empty ones + elif seg != '.': + keepers.append(seg) + # reassemble the kept segments + return leading_slash * '/' + '/'.join(keepers) + + +SCHEME_PATTERN = re.compile(r'([a-zA-Z][a-zA-Z0-9+\-.]*):') +def GetScheme(uriRef): + """ + Obtains, with optimum efficiency, just the scheme from a URI reference. + Returns a string, or if no scheme could be found, returns None. + """ + # Using a regex seems to be the best option. Called 50,000 times on + # different URIs, on a 1.0-GHz PIII with FreeBSD 4.7 and Python + # 2.2.1, this method completed in 0.95s, and 0.05s if there was no + # scheme to find. By comparison, + # urllib.splittype()[0] took 1.5s always; + # Ft.Lib.Uri.SplitUriRef()[0] took 2.5s always; + # urlparse.urlparse()[0] took 3.5s always. + m = SCHEME_PATTERN.match(uriRef) + if m is None: + return None + else: + return m.group(1) + + +def IsAbsolute(identifier): + """ + Given a string believed to be a URI or URI reference, tests that it is + absolute (as per RFC 2396), not relative -- i.e., that it has a scheme. + """ + # We do it this way to avoid compiling another massive regex. + return GetScheme(identifier) is not None Added: trunk/jython/Lib/xml/__init__.py =================================================================== --- trunk/jython/Lib/xml/__init__.py (rev 0) +++ trunk/jython/Lib/xml/__init__.py 2006-08-31 03:13:09 UTC (rev 2920) @@ -0,0 +1,41 @@ +"""Core XML support for Jython. + +This package contains two sub-packages: + +dom -- The W3C Document Object Model. This supports DOM Level 1 + + Namespaces. 
+ +sax -- The Simple API for XML, developed by XML-Dev, led by David + Megginson and ported to Python by Lars Marius Garshol. This + supports the SAX 2 API. + +""" + +__all__ = ['dom', 'sax'] + +# When being checked-out without options, this has the form +# "<dollar>Revision: x.y </dollar>" +# When exported using -kv, it is "x.y". +__version__ = "$Revision$".split()[-2:][0] + + +_MINIMUM_XMLPLUS_VERSION = (0, 8, 5) + + +try: + import _xmlplus +except ImportError: + pass +else: + try: + v = _xmlplus.version_info + except AttributeError: + # _xmlplus is too old; ignore it + pass + else: + if v >= _MINIMUM_XMLPLUS_VERSION: + import sys + _xmlplus.__path__.extend(__path__) + sys.modules[__name__] = _xmlplus + else: + del v Property changes on: trunk/jython/Lib/xml/__init__.py ___________________________________________________________________ Name: svn:keywords + Revision Added: trunk/jython/Lib/xml/dom/MessageSource.py =================================================================== --- trunk/jython/Lib/xml/dom/MessageSource.py (rev 0) +++ trunk/jython/Lib/xml/dom/MessageSource.py 2006-08-31 03:13:09 UTC (rev 2920) @@ -0,0 +1,54 @@ +# DOMException +from xml.dom import INDEX_SIZE_ERR, DOMSTRING_SIZE_ERR , HIERARCHY_REQUEST_ERR +from xml.dom import WRONG_DOCUMENT_ERR, INVALID_CHARACTER_ERR, NO_DATA_ALLOWED_ERR +from xml.dom import NO_MODIFICATION_ALLOWED_ERR, NOT_FOUND_ERR, NOT_SUPPORTED_ERR +from xml.dom import INUSE_ATTRIBUTE_ERR, INVALID_STATE_ERR, SYNTAX_ERR +from xml.dom import INVALID_MODIFICATION_ERR, NAMESPACE_ERR, INVALID_ACCESS_ERR +from xml.dom import VALIDATION_ERR + +# EventException +from xml.dom import UNSPECIFIED_EVENT_TYPE_ERR + +#Range Exceptions +from xml.dom import BAD_BOUNDARYPOINTS_ERR +from xml.dom import INVALID_NODE_TYPE_ERR + +# Fourthought Exceptions +from xml.dom import XML_PARSE_ERR + +from xml.FtCore import get_translator + +_ = get_translator("dom") + + +DOMExceptionStrings = { + INDEX_SIZE_ERR: _("Index error accessing NodeList or 
NamedNodeMap"), + DOMSTRING_SIZE_ERR: _("DOMString exceeds maximum size"), + HIERARCHY_REQUEST_ERR: _("Node manipulation results in invalid parent/child relationship."), + WRONG_DOCUMENT_ERR: _("Node is from a different document"), + INVALID_CHARACTER_ERR: _("Invalid or illegal character"), + NO_DATA_ALLOWED_ERR: _("Node does not support data"), + NO_MODIFICATION_ALLOWED_ERR: _("Attempt to modify a read-only object"), + NOT_FOUND_ERR: _("Node does not exist in this context"), + NOT_SUPPORTED_ERR: _("Object or operation not supported"), + INUSE_ATTRIBUTE_ERR: _("Attribute already in use by an element"), + INVALID_STATE_ERR: _("Object is not, or is no longer, usable"), + SYNTAX_ERR: _("Specified string is invalid or illegal"), + INVALID_MODIFICATION_ERR: _("Attempt to modify the type of a node"), + NAMESPACE_ERR: _("Invalid or illegal namespace operation"), + INVALID_ACCESS_ERR: _("Object does not support this operation or parameter"), + VALIDATION_ERR: _("Operation would invalidate partial validity constraint"), + } + +EventExceptionStrings = { + UNSPECIFIED_EVENT_TYPE_ERR : _("Uninitialized type in Event object"), + } + +FtExceptionStrings = { + XML_PARSE_ERR : _("XML parse error at line %d, column %d: %s"), + } + +RangeExceptionStrings = { + BAD_BOUNDARYPOINTS_ERR : _("Invalid Boundary Points specified for Range"), + INVALID_NODE_TYPE_ERR : _("Invalid Container Node") + } Added: trunk/jython/Lib/xml/dom/NodeFilter.py =================================================================== --- trunk/jython/Lib/xml/dom/NodeFilter.py (rev 0) +++ trunk/jython/Lib/xml/dom/NodeFilter.py 2006-08-31 03:13:09 UTC (rev 2920) @@ -0,0 +1,27 @@ +# This is the Python mapping for interface NodeFilter from +# DOM2-Traversal-Range. It contains only constants. + +class NodeFilter: + """ + This is the DOM2 NodeFilter interface. It contains only constants. 
+ """ + FILTER_ACCEPT = 1 + FILTER_REJECT = 2 + FILTER_SKIP = 3 + + SHOW_ALL = 0xFFFFFFFFL + SHOW_ELEMENT = 0x00000001 + SHOW_ATTRIBUTE = 0x00000002 + SHOW_TEXT = 0x00000004 + SHOW_CDATA_SECTION = 0x00000008 + SHOW_ENTITY_REFERENCE = 0x00000010 + SHOW_ENTITY = 0x00000020 + SHOW_PROCESSING_INSTRUCTION = 0x00000040 + SHOW_COMMENT = 0x00000080 + SHOW_DOCUMENT = 0x00000100 + SHOW_DOCUMENT_TYPE = 0x00000200 + SHOW_DOCUMENT_FRAGMENT = 0x00000400 + SHOW_NOTATION = 0x00000800 + + def acceptNode(self, node): + raise NotImplementedError Added: trunk/jython/Lib/xml/dom/__init__.py =================================================================== --- trunk/jython/Lib/xml/dom/__init__.py (rev 0) +++ trunk/jython/Lib/xml/dom/__init__.py 2006-08-31 03:13:09 UTC (rev 2920) @@ -0,0 +1,232 @@ +######################################################################## +# +# File Name: __init__.py +# +# +""" +WWW: http://4suite.org/4DOM e-mail: su...@4s... + +Copyright (c) 2000 Fourthought Inc, USA. All Rights Reserved. +See http://4suite.org/COPYRIGHT for license and copyright information +""" + + +class Node: + """Class giving the nodeType and tree-position constants.""" + + # DOM implementations may use this as a base class for their own + # Node implementations. If they don't, the constants defined here + # should still be used as the canonical definitions as they match + # the values given in the W3C recommendation. Client code can + # safely refer to these values in all tests of Node.nodeType + # values. 
+ + ELEMENT_NODE = 1 + ATTRIBUTE_NODE = 2 + TEXT_NODE = 3 + CDATA_SECTION_NODE = 4 + ENTITY_REFERENCE_NODE = 5 + ENTITY_NODE = 6 + PROCESSING_INSTRUCTION_NODE = 7 + COMMENT_NODE = 8 + DOCUMENT_NODE = 9 + DOCUMENT_TYPE_NODE = 10 + DOCUMENT_FRAGMENT_NODE = 11 + NOTATION_NODE = 12 + + # Based on DOM Level 3 (WD 9 April 2002) + + TREE_POSITION_PRECEDING = 0x01 + TREE_POSITION_FOLLOWING = 0x02 + TREE_POSITION_ANCESTOR = 0x04 + TREE_POSITION_DESCENDENT = 0x08 + TREE_POSITION_EQUIVALENT = 0x10 + TREE_POSITION_SAME_NODE = 0x20 + TREE_POSITION_DISCONNECTED = 0x00 + +class UserDataHandler: + """Class giving the operation constants for UserDataHandler.handle().""" + + # Based on DOM Level 3 (WD 9 April 2002) + + NODE_CLONED = 1 + NODE_IMPORTED = 2 + NODE_DELETED = 3 + NODE_RENAMED = 4 + +class DOMError: + """Class giving constants for error severity.""" + + # Based on DOM Level 3 (WD 9 April 2002) + + SEVERITY_WARNING = 0 + SEVERITY_ERROR = 1 + SEVERITY_FATAL_ERROR = 2 + + +# DOMException codes +INDEX_SIZE_ERR = 1 +DOMSTRING_SIZE_ERR = 2 +HIERARCHY_REQUEST_ERR = 3 +WRONG_DOCUMENT_ERR = 4 +INVALID_CHARACTER_ERR = 5 +NO_DATA_ALLOWED_ERR = 6 +NO_MODIFICATION_ALLOWED_ERR = 7 +NOT_FOUND_ERR = 8 +NOT_SUPPORTED_ERR = 9 +INUSE_ATTRIBUTE_ERR = 10 +# DOM Level 2 +INVALID_STATE_ERR = 11 +SYNTAX_ERR = 12 +INVALID_MODIFICATION_ERR = 13 +NAMESPACE_ERR = 14 +INVALID_ACCESS_ERR = 15 +# DOM Level 3 +VALIDATION_ERR = 16 + +# EventException codes +UNSPECIFIED_EVENT_TYPE_ERR = 0 + +# Fourthought specific codes +FT_EXCEPTION_BASE = 1000 +XML_PARSE_ERR = FT_EXCEPTION_BASE + 1 + +#RangeException codes +BAD_BOUNDARYPOINTS_ERR = 1 +INVALID_NODE_TYPE_ERR = 2 + + +class DOMException(Exception): + def __init__(self, code, msg=''): + self.code = code + self.msg = msg or DOMExceptionStrings[code] + + def __str__(self): + return self.msg + +class EventException(Exception): + def __init__(self, code, msg=''): + self.code = code + self.msg = msg or EventExceptionStrings[code] + return + + def __str__(self): 
+ return self.msg + +class RangeException(Exception): + def __init__(self, code, msg): + self.code = code + self.msg = msg or RangeExceptionStrings[code] + Exception.__init__(self, self.msg) + +class FtException(Exception): + def __init__(self, code, *args): + self.code = code + self.msg = FtExceptionStrings[code] % args + return + + def __str__(self): + return self.msg + +class IndexSizeErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, INDEX_SIZE_ERR, msg) + +class DomstringSizeErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, DOMSTRING_SIZE_ERR, msg) + +# DOMStringSizeErr was accidentally introduced in rev 1.14 of this +# file, and was released as part of PyXML 0.6.4, 0.6.5, 0.6.6, 0.7, +# and 0.7.1. It has never been part of the Python DOM API, although +# it better matches the W3C recommendation. It should remain for +# compatibility, unfortunately. +# +DOMStringSizeErr = DomstringSizeErr + +class HierarchyRequestErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, HIERARCHY_REQUEST_ERR, msg) + +class WrongDocumentErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, WRONG_DOCUMENT_ERR, msg) + +class InvalidCharacterErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, INVALID_CHARACTER_ERR, msg) + +class NoDataAllowedErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, NO_DATA_ALLOWED_ERR, msg) + +class NoModificationAllowedErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, NO_MODIFICATION_ALLOWED_ERR, msg) + +class NotFoundErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, NOT_FOUND_ERR, msg) + +class NotSupportedErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, NOT_SUPPORTED_ERR, msg) + +class InuseAttributeErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, INUSE_ATTRIBUTE_ERR, msg) + 
+class InvalidStateErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, INVALID_STATE_ERR, msg) + +class SyntaxErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, SYNTAX_ERR, msg) + +class InvalidModificationErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, INVALID_MODIFICATION_ERR, msg) + +class NamespaceErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, NAMESPACE_ERR, msg) + +class InvalidAccessErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, INVALID_ACCESS_ERR, msg) + +class ValidationErr(DOMException): + def __init__(self, msg=''): + DOMException.__init__(self, VALIDATION_ERR, msg) + +class UnspecifiedEventTypeErr(EventException): + def __init__(self, msg=''): + EventException.__init__(self, UNSPECIFIED_EVENT_TYPE_ERR, msg) + +class XmlParseErr(FtException): + def __init__(self, msg=''): + FtException.__init__(self, XML_PARSE_ERR, msg) + +#Specific Range Exceptions +class BadBoundaryPointsErr(RangeException): + def __init__(self, msg=''): + RangeException.__init__(self, BAD_BOUNDARYPOINTS_ERR, msg) + +class InvalidNodeTypeErr(RangeException): + def __init__(self, msg=''): + RangeException.__init__(self, INVALID_NODE_TYPE_ERR, msg) + +XML_NAMESPACE = "http://www.w3.org/XML/1998/namespace" +XMLNS_NAMESPACE = "http://www.w3.org/2000/xmlns/" +XHTML_NAMESPACE = "http://www.w3.org/1999/xhtml" +EMPTY_NAMESPACE = None +EMPTY_PREFIX = None + +import MessageSource +DOMExceptionStrings = MessageSource.__dict__['DOMExceptionStrings'] +EventExceptionStrings = MessageSource.__dict__['EventExceptionStrings'] +FtExceptionStrings = MessageSource.__dict__['FtExceptionStrings'] +RangeExceptionStrings = MessageSource.__dict__['RangeExceptionStrings'] + +from domreg import getDOMImplementation,registerDOMImplementation Added: trunk/jython/Lib/xml/dom/domreg.py =================================================================== --- 
trunk/jython/Lib/xml/dom/domreg.py (rev 0) +++ trunk/jython/Lib/xml/dom/domreg.py 2006-08-31 03:13:09 UTC (rev 2920) @@ -0,0 +1,99 @@ +"""Registration facilities for DOM. This module should not be used +directly. Instead, the functions getDOMImplementation and +registerDOMImplementation should be imported from xml.dom.""" + +from xml.dom.minicompat import * # isinstance, StringTypes + +# This is a list of well-known implementations. Well-known names +# should be published by posting to xm...@py..., and are +# subsequently recorded in this file. + +well_known_implementations = { + 'minidom':'xml.dom.minidom', + '4DOM': 'xml.dom.DOMImplementation', + } + +# DOM implementations not officially registered should register +# themselves with their + +registered = {} + +def registerDOMImplementation(name, factory): + """registerDOMImplementation(name, factory) + + Register the factory function with the name. The factory function + should return an object which implements the DOMImplementation + interface. The factory function can either return the same object, + or a new one (e.g. if that implementation supports some + customization).""" + + registered[name] = factory + +def _good_enough(dom, features): + "_good_enough(dom, features) -> Return 1 if the dom offers the features" + for f,v in features: + if not dom.hasFeature(f,v): + return 0 + return 1 + +def getDOMImplementation(name = None, features = ()): + """getDOMImplementation(name = None, features = ()) -> DOM implementation. + + Return a suitable DOM implementation. The name is either + well-known, the module name of a DOM implementation, or None. If + it is not None, imports the corresponding module and returns + DOMImplementation object if the import succeeds. + + If name is not given, consider the available implementations to + find one with the required feature set. If no implementation can + be found, raise an ImportError. 
The features list must be a sequence + of (feature, version) pairs which are passed to hasFeature.""" + + import os + creator = None + mod = well_known_implementations.get(name) + if mod: + mod = __import__(mod, {}, {}, ['getDOMImplementation']) + return mod.getDOMImplementation() + elif name: + return registered[name]() + elif os.environ.has_key("PYTHON_DOM"): + return getDOMImplementation(name = os.environ["PYTHON_DOM"]) + + # User did not specify a name, try implementations in arbitrary + # order, returning the one that has the required features + if isinstance(features, StringTypes): + features = _parse_feature_string(features) + for creator in registered.values(): + dom = creator() + if _good_enough(dom, features): + return dom + + for creator in well_known_implementations.keys(): + try: + dom = getDOMImplementation(name = creator) + except StandardError: # typically ImportError, or AttributeError + continue + if _good_enough(dom, features): + return dom + + raise ImportError,"no suitable DOM implementation found" + +def _parse_feature_string(s): + features = [] + parts = s.split() + i = 0 + length = len(parts) + while i < length: + feature = parts[i] + if feature[0] in "0123456789": + raise ValueError, "bad feature name: " + `feature` + i = i + 1 + version = None + if i < length: + v = parts[i] + if v[0] in "0123456789": + i = i + 1 + version = v + features.append((feature, version)) + return tuple(features) Added: trunk/jython/Lib/xml/dom/minicompat.py =================================================================== --- trunk/jython/Lib/xml/dom/minicompat.py (rev 0) +++ trunk/jython/Lib/xml/dom/minicompat.py 2006-08-31 03:13:09 UTC (rev 2920) @@ -0,0 +1,184 @@ +"""Python version compatibility support for minidom.""" + +# This module should only be imported using "import *". 
+# +# The following names are defined: +# +# isinstance -- version of the isinstance() function that accepts +# tuples as the second parameter regardless of the +# Python version +# +# NodeList -- lightest possible NodeList implementation +# +# EmptyNodeList -- lightest possible NodeList that is guarateed to +# remain empty (immutable) +# +# StringTypes -- tuple of defined string types +# +# GetattrMagic -- base class used to make _get_<attr> be magically +# invoked when available +# defproperty -- function used in conjunction with GetattrMagic; +# using these together is needed to make them work +# as efficiently as possible in both Python 2.2+ +# and older versions. For example: +# +# class MyClass(GetattrMagic): +# def _get_myattr(self): +# return something +# +# defproperty(MyClass, "myattr", +# "return some value") +# +# For Python 2.2 and newer, this will construct a +# property object on the class, which avoids +# needing to override __getattr__(). It will only +# work for read-only attributes. +# +# For older versions of Python, inheriting from +# GetattrMagic will use the traditional +# __getattr__() hackery to achieve the same effect, +# but less efficiently. +# +# defproperty() should be used for each version of +# the relevant _get_<property>() function. +# +# NewStyle -- base class to cause __slots__ to be honored in +# the new world +# +# True, False -- only for Python 2.2 and earlier + +__all__ = ["NodeList", "EmptyNodeList", "NewStyle", + "StringTypes", "defproperty", "GetattrMagic"] + +import xml.dom + +try: + unicode +except NameError: + StringTypes = type(''), +else: + StringTypes = type(''), type(unicode('')) + + +# define True and False only if not defined as built-ins +try: + True +except NameError: + True = 1 + False = 0 + __all__.extend(["True", "False"]) + + +try: + isinstance('', StringTypes) +except TypeError: + # + # Wrap isinstance() to make it compatible with the version in + # Python 2.2 and newer. 
+ # + _isinstance = isinstance + def isinstance(obj, type_or_seq): + try: + return _isinstance(obj, type_or_seq) + except TypeError: + for t in type_or_seq: + if _isinstance(obj, t): + return 1 + return 0 + __all__.append("isinstance") + + +if list is type([]): + class NodeList(list): + __slots__ = () + + def item(self, index): + if 0 <= index < len(self): + return self[index] + + def _get_length(self): + return len(self) + + def _set_length(self, value): + raise xml.dom.NoModificationAllowedErr( + "attempt to modify read-only attribute 'length'") + + length = property(_get_length, _set_length, + doc="The number of nodes in the NodeList.") + + def __getstate__(self): + return list(self) + + def __setstate__(self, state): + self[:] = state + + class EmptyNodeList(tuple): + __slots__ = () + + def __add__(self, other): + NL = NodeList() + NL.extend(other) + return NL + + def __radd__(self, other): + NL = NodeList() + NL.extend(other) + return NL + + def item(self, index): + return None + + def _get_length(self): + return 0 + + def _set_length(self, value): + raise xml.dom.NoModificationAllowedErr( + "attempt to modify read-only attribute 'length'") + + length = property(_get_length, _set_length, + doc="The number of nodes in the NodeList.") + +else: + def NodeList(): + return [] + + def EmptyNodeList(): + return [] + + +try: + property +except NameError: + def defproperty(klass, name, doc): + # taken care of by the base __getattr__() + pass + + class GetattrMagic: + def __getattr__(self, key): + if key.startswith("_"): + raise AttributeError, key + + try: + get = getattr(self, "_get_" + key) + except AttributeError: + raise AttributeError, key + return get() + + class NewStyle: + pass + +else: + def defproperty(klass, name, doc): + get = getattr(klass, ("_get_" + name)).im_func + def set(self, value, name=name): + raise xml.dom.NoModificationAllowedErr( + "attempt to modify read-only attribute " + repr(name)) + assert not hasattr(klass, "_set_" + name), \ + "expected 
not to find _set_" + name + prop = property(get, set, doc=doc) + setattr(klass, name, prop) + + class GetattrMagic: + pass + + NewStyle = object Added: trunk/jython/Lib/xml/dom/minidom.py =================================================================== --- trunk/jython/Lib/xml/dom/minidom.py (rev 0) +++ trunk/jython/Lib/xml/dom/minidom.py 2006-08-31 03:13:09 UTC (rev 2920) @@ -0,0 +1,1943 @@ +"""\ +minidom.py -- a lightweight DOM implementation. + +parse("foo.xml") + +parseString("<foo><bar/></foo>") + +Todo: +===== + * convenience methods for getting elements and text. + * more testing + * bring some of the writer and linearizer code into conformance with this + interface + * SAX 2 namespaces +""" + +import xml.dom + +from xml.dom import EMPTY_NAMESPACE, EMPTY_PREFIX, XMLNS_NAMESPACE, domreg +from xml.dom.minicompat import * +from xml.dom.xmlbuilder import DOMImplementationLS, DocumentLS + +_TupleType = type(()) + +# This is used by the ID-cache invalidation checks; the list isn't +# actually complete, since the nodes being checked will never be the +# DOCUMENT_NODE or DOCUMENT_FRAGMENT_NODE. (The node being checked is +# the node being added or removed, not the node being modified.) 
+# +_nodeTypes_with_children = (xml.dom.Node.ELEMENT_NODE, + xml.dom.Node.ENTITY_REFERENCE_NODE) + + +class Node(xml.dom.Node, GetattrMagic): + namespaceURI = None # this is non-null only for elements and attributes + parentNode = None + ownerDocument = None + nextSibling = None + previousSibling = None + + prefix = EMPTY_PREFIX # non-null only for NS elements and attributes + + def __nonzero__(self): + return True + + def toxml(self, encoding = None): + return self.toprettyxml("", "", encoding) + + def toprettyxml(self, indent="\t", newl="\n", encoding = None): + # indent = the indentation string to prepend, per level + # newl = the newline string to append + writer = _get_StringIO() + if encoding is not None: + import codecs + # Can't use codecs.getwriter to preserve 2.0 compatibility + writer = codecs.lookup(encoding)[3](writer) + if self.nodeType == Node.DOCUMENT_NODE: + # Can pass encoding only to document, to put it into XML header + self.writexml(writer, "", indent, newl, encoding) + else: + self.writexml(writer, "", indent, newl) + return writer.getvalue() + + def hasAttributes(self): + return False + + def hasChildNodes(self): + if self.childNodes: + return True + else: + return False + + def _get_childNodes(self): + return self.childNodes + + def _get_firstChild(self): + if self.childNodes: + return self.childNodes[0] + + def _get_lastChild(self): + if self.childNodes: + return self.childNodes[-1] + + def insertBefore(self, newChild, refChild): + if newChild.nodeType == self.DOCUMENT_FRAGMENT_NODE: + for c in tuple(newChild.childNodes): + self.insertBefore(c, refChild) + ### The DOM does not clearly specify what to return in this case + return newChild + if newChild.nodeType not in self._child_node_types: + raise xml.dom.HierarchyRequestErr( + "%s cannot be child of %s" % (repr(newChild), repr(self))) + if newChild.parentNode is not None: + newChild.parentNode.removeChild(newChild) + if refChild is None: + self.appendChild(newChild) + else: + try: + index 
= self.childNodes.index(refChild) + except ValueError: + raise xml.dom.NotFoundErr() + if newChild.nodeType in _nodeTypes_with_children: + _clear_id_cache(self) + self.childNodes.insert(index, newChild) + newChild.nextSibling = refChild + refChild.previousSibling = newChild + if index: + node = self.childNodes[index-1] + node.nextSibling = newChild + newChild.previousSibling = node + else: + newChild.previousSibling = None + newChild.parentNode = self + return newChild + + def appendChild(self, node): + if node.nodeType == self.DOCUMENT_FRAGMENT_NODE: + for c in tuple(node.childNodes): + self.appendChild(c) + ### The DOM does not clearly specify what to return in this case + return node + if node.nodeType not in self._child_node_types: + raise xml.dom.HierarchyRequestErr( + "%s cannot be child of %s" % (repr(node), repr(self))) + elif node.nodeType in _nodeTypes_with_children: + _clear_id_cache(self) + if node.parentNode is not None: + node.parentNode.removeChild(node) + _append_child(self, node) + node.nextSibling = None + return node + + def replaceChild(self, newChild, oldChild): + if newChild.nodeType == self.DOCUMENT_FRAGMENT_NODE: + refChild = oldChild.nextSibling + self.removeChild(oldChild) + return self.insertBefore(newChild, refChild) + if newChild.nodeType not in self._child_node_types: + raise xml.dom.HierarchyRequestErr( + "%s cannot be child of %s" % (repr(newChild), repr(self))) + if newChild.parentNode is not None: + newChild.parentNode.removeChild(newChild) + if newChild is oldChild: + return + try: + index = self.childNodes.index(oldChild) + except ValueError: + raise xml.dom.NotFoundErr() + if (newChild.nodeType in _nodeTypes_with_children + or oldChild.nodeType in _nodeTypes_with_children): + _clear_id_cache(self) + self.childNodes[index] = newChild + newChild.parentNode = self + oldChild.parentNode = None + newChild.nextSibling = oldChild.nextSibling + newChild.previousSibling = oldChild.previousSibling + oldChild.nextSibling = None + 
oldChild.previousSibling = None + if newChild.previousSibling: + newChild.previousSibling.nextSibling = newChild + if newChild.nextSibling: + newChild.nextSibling.previousSibling = newChild + return oldChild + + def removeChild(self, oldChild): + try: + self.childNodes.remove(oldChild) + except ValueError: + raise xml.dom.NotFoundErr() + if oldChild.nextSibling is not None: + oldChild.nextSibling.previousSibling = oldChild.previousSibling + if oldChild.previousSibling is not None: + oldChild.previousSibling.nextSibling = oldChild.nextSibling + oldChild.nextSibling = oldChild.previousSibling = None + if oldChild.nodeType in _nodeTypes_with_children: + _clear_id_cache(self) + + oldChild.parentNode = None + return oldChild + + def normalize(self): + L = [] + for child in self.childNodes: + if child.nodeType == Node.TEXT_NODE: + data = child.data + if data and L and L[-1].nodeType == child.nodeType: + # collapse text node + node = L[-1] + node.data = node.data + child.data + node.nextSibling = child.nextSibling + child.unlink() + elif data: + if L: + L[-1].nextSibling = child + child.previousSibling = L[-1] + else: + child.previousSibling = None + L.append(child) + else: + # empty text node; discard + child.unlink() + else: + if L: + L[-1].nextSibling = child + child.previousSibling = L[-1] + else: + child.previousSibling = None + L.append(child) + if child.nodeType == Node.ELEMENT_NODE: + child.normalize() + if self.childNodes: + self.childNodes[:] = L + return + + def cloneNode(self, deep): + return _clone_node(self, deep, self.ownerDocument or self) + + def isSupported(self, feature, version): + return self.ownerDocument.implementation.hasFeature(feature, version) + + def _get_localName(self): + # Overridden in Element and Attr where localName can be Non-Null + return None + + # Node interfaces from Level 3 (WD 9 April 2002) + + def isSameNode(self, other): + return self is other + + def getInterface(self, feature): + if self.isSupported(feature, None): + return 
self + else: + return None + + # The "user data" functions use a dictionary that is only present + # if some user data has been set, so be careful not to assume it + # exists. + + def getUserData(self, key): + try: + return self._user_data[key][0] + except (AttributeError, KeyError): + return None + + def setUserData(self, key, data, handler): + old = None + try: + d = self._user_data + except AttributeError: + d = {} + self._user_data = d + if d.has_key(key): + old = d[key][0] + if data is None: + # ignore handlers passed for None + handler = None + if old is not None: + del d[key] + else: + d[key] = (data, handler) + return old + + def _call_user_data_handler(self, operation, src, dst): + if hasattr(self, "_user_data"): + for key, (data, handler) in self._user_data.items(): + if handler is not None: + handler.handle(operation, key, data, src, dst) + + # minidom-specific API: + + def unlink(self): + self.parentNode = self.ownerDocument = None + if self.childNodes: + for child in self.childNodes: + child.unlink() + self.childNodes = NodeList() + self.previousSibling = None + self.nextSibling = None + +defproperty(Node, "firstChild", doc="First child node, or None.") +defproperty(Node, "lastChild", doc="Last child node, or None.") +defproperty(Node, "localName", doc="Namespace-local name of this node.") + + +def _append_child(self, node): + # fast path with less checks; usable by DOM builders if careful + childNodes = self.childNodes + if childNodes: + last = childNodes[-1] + node.__dict__["previousSibling"] = last + last.__dict__["nextSibling"] = node + childNodes.append(node) + node.__dict__["parentNode"] = self + +def _in_document(node): + # return True iff node is part of a document tree + while node is not None: + if node.nodeType == Node.DOCUMENT_NODE: + return True + node = node.parentNode + return False + +def _write_data(writer, data): + "Writes datachars to writer." 
+    data = data.replace("&", "&amp;").replace("<", "&lt;")
+    data = data.replace("\"", "&quot;").replace(">", "&gt;")
+    writer.write(data)
+
+def _get_elements_by_tagName_helper(parent, name, rc):
+    for node in parent.childNodes:
+        if node.nodeType == Node.ELEMENT_NODE and \
+            (name == "*" or node.tagName == name):
+            rc.append(node)
+        _get_elements_by_tagName_helper(node, name, rc)
+    return rc
+
+def _get_elements_by_tagName_ns_helper(parent, nsURI, localName, rc):
+    for node in parent.childNodes:
+        if node.nodeType == Node.ELEMENT_NODE:
+            if ((localName == "*" or node.localName == localName) and
+                (nsURI == "*" or node.namespaceURI == nsURI)):
+                rc.append(node)
+            _get_elements_by_tagName_ns_helper(node, nsURI, localName, rc)
+    return rc
+
+class DocumentFragment(Node):
+    nodeType = Node.DOCUMENT_FRAGMENT_NODE
+    nodeName = "#document-fragment"
+    nodeValue = None
+    attributes = None
+    parentNode = None
+    _child_node_types = (Node.ELEMENT_NODE,
+                         Node.TEXT_NODE,
+                         Node.CDATA_SECTION_NODE,
+                         Node.ENTITY_REFERENCE_NODE,
+                         Node.PROCESSING_INSTRUCTION_NODE,
+                         Node.COMMENT_NODE,
+                         Node.NOTATION_NODE)
+
+    def __init__(self):
+        self.childNodes = NodeList()
+
+
+class Attr(Node):
+    nodeType = Node.ATTRIBUTE_NODE
+    attributes = None
+    ownerElement = None
+    specified = False
+    _is_id = False
+
+    _child_node_types = (Node.TEXT_NODE, Node.ENTITY_REFERENCE_NODE)
+
+    def __init__(self, qName, namespaceURI=EMPTY_NAMESPACE, localName=None,
+                 prefix=None):
+        # skip setattr for performance
+        d = self.__dict__
+        d["nodeName"] = d["name"] = qName
+        d["namespaceURI"] = namespaceURI
+        d["prefix"] = prefix
+        d['childNodes'] = NodeList()
+
+        # Add the single child node that represents the value of the attr
+        self.childNodes.append(Text())
+
+        # nodeValue and value are set elsewhere
+
+    def _get_localName(self):
+        return self.nodeName.split(":", 1)[-1]
+
+    def _get_name(self):
+        return self.name
+
+    def _get_specified(self):
+        return self.specified
+
+    def __setattr__(self, name, value):
+        d = self.__dict__
+ if name in ("value", "nodeValue"): + d["value"] = d["nodeValue"] = value + d2 = self.childNodes[0].__dict__ + d2["data"] = d2["nodeValue"] = value + if self.ownerElement is not None: + _clear_id_cache(self.ownerElement) + elif name in ("name", "nodeName"): + d["name"] = d["nodeName"] = value + if self.ownerElement is not None: + _clear_id_cache(self.ownerElement) + else: + d[name] = value + + def _set_prefix(self, prefix): + nsuri = self.namespaceURI + if prefix == "xmlns": + if nsuri and nsuri != XMLNS_NAMESPACE: + raise xml.dom.NamespaceErr( + "illegal use of 'xmlns' prefix for the wrong namespace") + d = self.__dict__ + d['prefix'] = prefix + if prefix is None: + newName = self.localName + else: + newName = "%s:%s" % (prefix, self.localName) + if self.ownerElement: + _clear_id_cache(self.ownerElement) + d['nodeName'] = d['name'] = newName + + def _set_value(self, value): + d = self.__dict__ + d['value'] = d['nodeValue'] = value + if self.ownerElement: + _clear_id_cache(self.ownerElement) + self.childNodes[0].data = value + + def unlink(self): + # This implementation does not call the base implementation + # since most of that is not needed, and the expense of the + # method call is not warranted. We duplicate the removal of + # children, but that's all we needed from the base class. 
+ elem = self.ownerElement + if elem is not None: + del elem._attrs[self.nodeName] + del elem._attrsNS[(self.namespaceURI, self.localName)] + if self._is_id: + self._is_id = False + elem._magic_id_nodes -= 1 + self.ownerDocument._magic_id_count -= 1 + for child in self.childNodes: + child.unlink() + del self.childNodes[:] + + def _get_isId(self): + if self._is_id: + return True + doc = self.ownerDocument + elem = self.ownerElement + if doc is None or elem is None: + return False + + info = doc._get_elem_info(elem) + if info is None: + return False + if self.namespaceURI: + return info.isIdNS(self.namespaceURI, self.localName) + else: + return info.isId(self.nodeName) + + def _get_schemaType(self): + doc = self.ownerDocument + elem = self.ownerElement + if doc is None or elem is None: + return _no_type + + info = doc._get_elem_info(elem) + if info is None: + return _no_type + if self.namespaceURI: + return info.getAttributeTypeNS(self.namespaceURI, self.localName) + else: + return info.getAttributeType(self.nodeName) + +defproperty(Attr, "isId", doc="True if this attribute is an ID.") +defproperty(Attr, "localName", doc="Namespace-local name of this attribute.") +defproperty(Attr, "schemaType", doc="Schema type for this attribute.") + + +class NamedNodeMap(NewStyle, GetattrMagic): + """The attribute list is a transient interface to the underlying + dictionaries. Mutations here will change the underlying element's + dictionary. + + Ordering is imposed artificially and does not reflect the order of + attributes as found in an input document. 
+ """ + + __slots__ = ('_attrs', '_attrsNS', '_ownerElement') + + def __init__(self, attrs, attrsNS, ownerElement): + self._attrs = attrs + self._attrsNS = attrsNS + self._ownerElement = ownerElement + + def _get_length(self): + return len(self._attrs) + + def item(self, index): + try: + return self[self._attrs.keys()[index]] + except IndexError: + return None + + def items(self): + L = [] + for node in self._attrs.values(): + L.append((node.nodeName, node.value)) + return L + + def itemsNS(self): + L = [] + for n... [truncated message content] |
From: <cg...@us...> - 2007-01-31 05:11:48
|
Revision: 3068 http://svn.sourceforge.net/jython/?rev=3068&view=rev Author: cgroves Date: 2007-01-30 21:11:46 -0800 (Tue, 30 Jan 2007) Log Message: ----------- copy various path attributes into path module and add realpath method to fix bug #1534547 Modified Paths: -------------- trunk/jython/Lib/javaos.py trunk/jython/Lib/javapath.py Modified: trunk/jython/Lib/javaos.py =================================================================== --- trunk/jython/Lib/javaos.py 2007-01-31 04:43:53 UTC (rev 3067) +++ trunk/jython/Lib/javaos.py 2007-01-31 05:11:46 UTC (rev 3068) @@ -88,6 +88,13 @@ extsep = '/' else: extsep = '.' +path.curdir = curdir +path.pardir = pardir +path.sep = sep +path.altsep = altsep +path.pathsep = pathsep +path.defpath = defpath +path.extsep = extsep def _exit(n=0): java.lang.System.exit(n) Modified: trunk/jython/Lib/javapath.py =================================================================== --- trunk/jython/Lib/javapath.py 2007-01-31 04:43:53 UTC (rev 3067) +++ trunk/jython/Lib/javapath.py 2007-01-31 05:11:46 UTC (rev 3068) @@ -20,6 +20,7 @@ from java.lang import System import os + def _tostr(s, method): if isinstance(s, "".__class__): return s @@ -238,6 +239,9 @@ path = _tostr(path, "abspath") return File(path).getAbsolutePath() +def realpath(path): + path = _tostr(path, "realpath") + return File(path).getCanonicalPath() def getsize(path): path = _tostr(path, "getsize") This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <cg...@us...> - 2007-04-30 06:50:55
|
Revision: 3200 http://svn.sourceforge.net/jython/?rev=3200&view=rev Author: cgroves Date: 2007-04-29 23:50:50 -0700 (Sun, 29 Apr 2007) Log Message: ----------- Don't check co_code on Jython since it doesn't exist. The means getargs doesn't correctly return anonymous tuple arguments Modified Paths: -------------- trunk/jython/Lib/inspect.py trunk/jython/Lib/test/test_inspect.py Modified: trunk/jython/Lib/inspect.py =================================================================== --- trunk/jython/Lib/inspect.py 2007-04-30 06:45:58 UTC (rev 3199) +++ trunk/jython/Lib/inspect.py 2007-04-30 06:50:50 UTC (rev 3200) @@ -27,7 +27,7 @@ __author__ = 'Ka-Ping Yee <pi...@lf...>' __date__ = '1 Jan 2001' -import sys, os, types, string, re, dis, imp, tokenize +import sys, os, types, string, re, imp, tokenize # ----------------------------------------------------------- type-checking def ismodule(object): @@ -571,37 +571,39 @@ 'varargs' and 'varkw' are the names of the * and ** arguments or None.""" if not iscode(co): raise TypeError, 'arg is not a code object' - code = co.co_code nargs = co.co_argcount names = co.co_varnames args = list(names[:nargs]) step = 0 # The following acrobatics are for anonymous (tuple) arguments. 
- for i in range(nargs): - if args[i][:1] in ['', '.']: - stack, remain, count = [], [], [] - while step < len(code): - op = ord(code[step]) - step = step + 1 - if op >= dis.HAVE_ARGUMENT: - opname = dis.opname[op] - value = ord(code[step]) + ord(code[step+1])*256 - step = step + 2 - if opname in ['UNPACK_TUPLE', 'UNPACK_SEQUENCE']: - remain.append(value) - count.append(value) - elif opname == 'STORE_FAST': - stack.append(names[value]) - remain[-1] = remain[-1] - 1 - while remain[-1] == 0: - remain.pop() - size = count.pop() - stack[-size:] = [stack[-size:]] + if not sys.platform.startswith('java'):#Jython doesn't have co_code + code = co.co_code + import dis + for i in range(nargs): + if args[i][:1] in ['', '.']: + stack, remain, count = [], [], [] + while step < len(code): + op = ord(code[step]) + step = step + 1 + if op >= dis.HAVE_ARGUMENT: + opname = dis.opname[op] + value = ord(code[step]) + ord(code[step+1])*256 + step = step + 2 + if opname in ['UNPACK_TUPLE', 'UNPACK_SEQUENCE']: + remain.append(value) + count.append(value) + elif opname == 'STORE_FAST': + stack.append(names[value]) + remain[-1] = remain[-1] - 1 + while remain[-1] == 0: + remain.pop() + size = count.pop() + stack[-size:] = [stack[-size:]] + if not remain: break + remain[-1] = remain[-1] - 1 if not remain: break - remain[-1] = remain[-1] - 1 - if not remain: break - args[i] = stack[0] + args[i] = stack[0] varargs = None if co.co_flags & CO_VARARGS: Modified: trunk/jython/Lib/test/test_inspect.py =================================================================== --- trunk/jython/Lib/test/test_inspect.py 2007-04-30 06:45:58 UTC (rev 3199) +++ trunk/jython/Lib/test/test_inspect.py 2007-04-30 06:50:50 UTC (rev 3200) @@ -62,7 +62,7 @@ # getsourcefile, getcomments, getsource, getclasstree, getargspec, # getargvalues, formatargspec, formatargvalues, currentframe, stack, trace -from test_support import TestFailed, TESTFN +from test_support import TestFailed, TESTFN, is_jython import sys, imp, os, 
string def test(assertion, message, *args): @@ -94,7 +94,7 @@ except: tb = sys.exc_traceback -istest(inspect.isbuiltin, 'sys.exit') +istest(inspect.isbuiltin, 'ord') istest(inspect.isbuiltin, '[].append') istest(inspect.isclass, 'mod.StupidGit') istest(inspect.iscode, 'mod.spam.func_code') @@ -155,14 +155,15 @@ test(defaults == None, 'mod.eggs defaults') test(inspect.formatargspec(args, varargs, varkw, defaults) == '(x, y)', 'mod.eggs formatted argspec') -args, varargs, varkw, defaults = inspect.getargspec(mod.spam) -test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.spam args') -test(varargs == 'g', 'mod.spam varargs') -test(varkw == 'h', 'mod.spam varkw') -test(defaults == (3, (4, (5,))), 'mod.spam defaults') -test(inspect.formatargspec(args, varargs, varkw, defaults) == - '(a, b, c, d=3, (e, (f,))=(4, (5,)), *g, **h)', - 'mod.spam formatted argspec') +if not is_jython:#Jython can't handle this without co_code + args, varargs, varkw, defaults = inspect.getargspec(mod.spam) + test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.spam args') + test(varargs == 'g', 'mod.spam varargs') + test(varkw == 'h', 'mod.spam varkw') + test(defaults == (3, (4, (5,))), 'mod.spam defaults') + test(inspect.formatargspec(args, varargs, varkw, defaults) == + '(a, b, c, d=3, (e, (f,))=(4, (5,)), *g, **h)', + 'mod.spam formatted argspec') git.abuse(7, 8, 9) @@ -200,13 +201,14 @@ test(inspect.formatargvalues(args, varargs, varkw, locals) == '(x=11, y=14)', 'mod.fr formatted argvalues') -args, varargs, varkw, locals = inspect.getargvalues(mod.fr.f_back) -test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.fr.f_back args') -test(varargs == 'g', 'mod.fr.f_back varargs') -test(varkw == 'h', 'mod.fr.f_back varkw') -test(inspect.formatargvalues(args, varargs, varkw, locals) == - '(a=7, b=8, c=9, d=3, (e=4, (f=5,)), *g=(), **h={})', - 'mod.fr.f_back formatted argvalues') +if not is_jython: + args, varargs, varkw, locals = inspect.getargvalues(mod.fr.f_back) + test(args == ['a', 'b', 
'c', 'd', ['e', ['f']]], 'mod.fr.f_back args') + test(varargs == 'g', 'mod.fr.f_back varargs') + test(varkw == 'h', 'mod.fr.f_back varkw') + test(inspect.formatargvalues(args, varargs, varkw, locals) == + '(a=7, b=8, c=9, d=3, (e=4, (f=5,)), *g=(), **h={})', + 'mod.fr.f_back formatted argvalues') for fname in files_to_clean_up: try: This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <cg...@us...> - 2007-04-30 06:46:05
|
Revision: 3199 http://svn.sourceforge.net/jython/?rev=3199&view=rev Author: cgroves Date: 2007-04-29 23:45:58 -0700 (Sun, 29 Apr 2007) Log Message: ----------- pulling clean copies of these files from http://svn.python.org/projects/python/branches/release22-maint/Lib@54948 to overlay Jython fixes on top of Modified Paths: -------------- trunk/jython/Lib/inspect.py trunk/jython/Lib/test/test_inspect.py Modified: trunk/jython/Lib/inspect.py =================================================================== --- trunk/jython/Lib/inspect.py 2007-04-29 08:59:39 UTC (rev 3198) +++ trunk/jython/Lib/inspect.py 2007-04-30 06:45:58 UTC (rev 3199) @@ -1,23 +1,785 @@ -# -# Very simple version of inspect, just enough is supported for -# doctest to work. -# +"""Get useful information from live Python objects. -import org.python.core as _core +This module encapsulates the interface provided by the internal special +attributes (func_*, co_*, im_*, tb_*, etc.) in a friendlier fashion. +It also provides some help for examining source code and class layout. 
-def isclass(cls): - return isinstance(cls, _core.PyClass) +Here are some of the useful functions provided by this module: -def isfunction(func): - return isinstance(func, _core.PyFunction) + ismodule(), isclass(), ismethod(), isfunction(), istraceback(), + isframe(), iscode(), isbuiltin(), isroutine() - check object types + getmembers() - get members of an object that satisfy a given condition -def ismodule(mod): - return isinstance(mod, _core.PyModule) + getfile(), getsourcefile(), getsource() - find an object's source code + getdoc(), getcomments() - get documentation on an object + getmodule() - determine the module that an object came from + getclasstree() - arrange classes so as to represent their hierarchy -def ismethod(meth): - return isinstance(meth, _core.PyMethod) + getargspec(), getargvalues() - get info about function arguments + formatargspec(), formatargvalues() - format an argument spec + getouterframes(), getinnerframes() - get info about frames + currentframe() - get the current stack frame + stack(), trace() - get info about frames on the stack or in a traceback +""" -def classify_class_attrs(obj): - return [] +# This module is in the public domain. No warranties. +__author__ = 'Ka-Ping Yee <pi...@lf...>' +__date__ = '1 Jan 2001' +import sys, os, types, string, re, dis, imp, tokenize + +# ----------------------------------------------------------- type-checking +def ismodule(object): + """Return true if the object is a module. + + Module objects provide these attributes: + __doc__ documentation string + __file__ filename (missing for built-in modules)""" + return isinstance(object, types.ModuleType) + +def isclass(object): + """Return true if the object is a class. 
+ + Class objects provide these attributes: + __doc__ documentation string + __module__ name of module in which this class was defined""" + return isinstance(object, types.ClassType) or hasattr(object, '__bases__') + +def ismethod(object): + """Return true if the object is an instance method. + + Instance method objects provide these attributes: + __doc__ documentation string + __name__ name with which this method was defined + im_class class object in which this method belongs + im_func function object containing implementation of method + im_self instance to which this method is bound, or None""" + return isinstance(object, types.MethodType) + +def ismethoddescriptor(object): + """Return true if the object is a method descriptor. + + But not if ismethod() or isclass() or isfunction() are true. + + This is new in Python 2.2, and, for example, is true of int.__add__. + An object passing this test has a __get__ attribute but not a __set__ + attribute, but beyond that the set of attributes varies. __name__ is + usually sensible, and __doc__ often is. + + Methods implemented via descriptors that also pass one of the other + tests return false from the ismethoddescriptor() test, simply because + the other tests promise more -- you can, e.g., count on having the + im_func attribute (etc) when an object passes ismethod().""" + return (hasattr(object, "__get__") + and not hasattr(object, "__set__") # else it's a data descriptor + and not ismethod(object) # mutual exclusion + and not isfunction(object) + and not isclass(object)) + +def isfunction(object): + """Return true if the object is a user-defined function. 
+ + Function objects provide these attributes: + __doc__ documentation string + __name__ name with which this function was defined + func_code code object containing compiled function bytecode + func_defaults tuple of any default values for arguments + func_doc (same as __doc__) + func_globals global namespace in which this function was defined + func_name (same as __name__)""" + return isinstance(object, types.FunctionType) + +def istraceback(object): + """Return true if the object is a traceback. + + Traceback objects provide these attributes: + tb_frame frame object at this level + tb_lasti index of last attempted instruction in bytecode + tb_lineno current line number in Python source code + tb_next next inner traceback object (called by this level)""" + return isinstance(object, types.TracebackType) + +def isframe(object): + """Return true if the object is a frame object. + + Frame objects provide these attributes: + f_back next outer frame object (this frame's caller) + f_builtins built-in namespace seen by this frame + f_code code object being executed in this frame + f_exc_traceback traceback if raised in this frame, or None + f_exc_type exception type if raised in this frame, or None + f_exc_value exception value if raised in this frame, or None + f_globals global namespace seen by this frame + f_lasti index of last attempted instruction in bytecode + f_lineno current line number in Python source code + f_locals local namespace seen by this frame + f_restricted 0 or 1 if frame is in restricted execution mode + f_trace tracing function for this frame, or None""" + return isinstance(object, types.FrameType) + +def iscode(object): + """Return true if the object is a code object. 
    Code objects provide these attributes:
        co_argcount     number of arguments (not including * or ** args)
        co_code         string of raw compiled bytecode
        co_consts       tuple of constants used in the bytecode
        co_filename     name of file in which this code object was created
        co_firstlineno  number of first line in Python source code
        co_flags        bitmap: 1=optimized | 2=newlocals | 4=*arg | 8=**arg
        co_lnotab       encoded mapping of line numbers to bytecode indices
        co_name         name with which this code object was defined
        co_names        tuple of names of local variables
        co_nlocals      number of local variables
        co_stacksize    virtual machine stack space required
        co_varnames     tuple of names of arguments and local variables"""
    return isinstance(object, types.CodeType)

def isbuiltin(object):
    """Return true if the object is a built-in function or method.

    Built-in functions and methods provide these attributes:
        __doc__         documentation string
        __name__        original name of this function or method
        __self__        instance to which a method is bound, or None"""
    return isinstance(object, types.BuiltinFunctionType)

def isroutine(object):
    """Return true if the object is any kind of function or method."""
    return (isbuiltin(object)
            or isfunction(object)
            or ismethod(object)
            or ismethoddescriptor(object))

def getmembers(object, predicate=None):
    """Return all members of an object as (name, value) pairs sorted by name.
    Optionally, only return members that satisfy a given predicate."""
    results = []
    for key in dir(object):
        value = getattr(object, key)
        if not predicate or predicate(value):
            results.append((key, value))
    results.sort()
    return results

def classify_class_attrs(cls):
    """Return list of attribute-descriptor tuples.

    For each name in dir(cls), the return list contains a 4-tuple
    with these elements:

        0. The name (a string).

        1. The kind of attribute this is, one of these strings:
               'class method'    created via classmethod()
               'static method'   created via staticmethod()
               'property'        created via property()
               'method'          any other flavor of method
               'data'            not a method

        2. The class which defined this attribute (a class).

        3. The object as obtained directly from the defining class's
           __dict__, not via getattr.  This is especially important for
           data attributes:  C.data is just a data object, but
           C.__dict__['data'] may be a data descriptor with additional
           info, like a __doc__ string.
    """

    mro = getmro(cls)
    names = dir(cls)
    result = []
    for name in names:
        # Get the object associated with the name.
        # Getting an obj from the __dict__ sometimes reveals more than
        # using getattr.  Static and class methods are dramatic examples.
        if name in cls.__dict__:
            obj = cls.__dict__[name]
        else:
            obj = getattr(cls, name)

        # Figure out where it was defined.
        homecls = getattr(obj, "__objclass__", None)
        if homecls is None:
            # search the dicts.
            for base in mro:
                if name in base.__dict__:
                    homecls = base
                    break

        # Get the object again, in order to get it from the defining
        # __dict__ instead of via getattr (if possible).
        if homecls is not None and name in homecls.__dict__:
            obj = homecls.__dict__[name]

        # Also get the object via getattr.
        obj_via_getattr = getattr(cls, name)

        # Classify the object.
        if isinstance(obj, staticmethod):
            kind = "static method"
        elif isinstance(obj, classmethod):
            kind = "class method"
        elif isinstance(obj, property):
            kind = "property"
        elif (ismethod(obj_via_getattr) or
              ismethoddescriptor(obj_via_getattr)):
            kind = "method"
        else:
            kind = "data"

        result.append((name, kind, homecls, obj))

    return result

# ----------------------------------------------------------- class helpers
def _searchbases(cls, accum):
    # Simulate the "classic class" search order.
    if cls in accum:
        return
    accum.append(cls)
    for base in cls.__bases__:
        _searchbases(base, accum)

def getmro(cls):
    "Return tuple of base classes (including cls) in method resolution order."
    if hasattr(cls, "__mro__"):
        return cls.__mro__
    else:
        # Classic class: compute depth-first, left-to-right order by hand.
        result = []
        _searchbases(cls, result)
        return tuple(result)

# -------------------------------------------------- source code extraction
def indentsize(line):
    """Return the indent size, in spaces, at the start of a line of text."""
    expline = string.expandtabs(line)
    return len(expline) - len(string.lstrip(expline))

def getdoc(object):
    """Get the documentation string for an object.

    All tabs are expanded to spaces.  To clean up docstrings that are
    indented to line up with blocks of code, any whitespace that can be
    uniformly removed from the second line onwards is removed."""
    try:
        doc = object.__doc__
    except AttributeError:
        return None
    if not isinstance(doc, (str, unicode)):
        return None
    try:
        lines = string.split(string.expandtabs(doc), '\n')
    except UnicodeError:
        return None
    else:
        # Find the smallest indentation of all non-blank lines after the
        # first, then strip that margin from every subsequent line.
        margin = None
        for line in lines[1:]:
            content = len(string.lstrip(line))
            if not content: continue
            indent = len(line) - content
            if margin is None: margin = indent
            else: margin = min(margin, indent)
        if margin is not None:
            for i in range(1, len(lines)): lines[i] = lines[i][margin:]
        return string.join(lines, '\n')

def getfile(object):
    """Work out which source or compiled file an object was defined in."""
    if ismodule(object):
        if hasattr(object, '__file__'):
            return object.__file__
        raise TypeError, 'arg is a built-in module'
    if isclass(object):
        object = sys.modules.get(object.__module__)
        if hasattr(object, '__file__'):
            return object.__file__
        raise TypeError, 'arg is a built-in class'
    # Successively unwrap method -> function -> code; traceback/frame -> code.
    if ismethod(object):
        object = object.im_func
    if isfunction(object):
        object = object.func_code
    if istraceback(object):
        object = object.tb_frame
    if isframe(object):
        object = object.f_code
    if iscode(object):
        return object.co_filename
    raise TypeError, 'arg is not a module, class, method, ' \
                     'function, traceback, frame, or code object'

def getmoduleinfo(path):
    """Get the module name, suffix, mode, and module type for a given file."""
    filename = os.path.basename(path)
    suffixes = map(lambda (suffix, mode, mtype):
                   (-len(suffix), suffix, mode, mtype), imp.get_suffixes())
    suffixes.sort() # try longest suffixes first, in case they overlap
    for neglen, suffix, mode, mtype in suffixes:
        if filename[neglen:] == suffix:
            return filename[:neglen], suffix, mode, mtype

def getmodulename(path):
    """Return the module name for a given file, or None."""
    info = getmoduleinfo(path)
    if info: return info[0]

def getsourcefile(object):
    """Return the Python source file an object was defined in, if it exists."""
    filename = getfile(object)
    if string.lower(filename[-4:]) in ['.pyc', '.pyo']:
        filename = filename[:-4] + '.py'
    for suffix, mode, kind in imp.get_suffixes():
        if 'b' in mode and string.lower(filename[-len(suffix):]) == suffix:
            # Looks like a binary file.  We want to only return a text file.
            return None
    if os.path.exists(filename):
        return filename

def getabsfile(object):
    """Return an absolute path to the source or compiled file for an object.

    The idea is for each object to have a unique origin, so this routine
    normalizes the result as much as possible."""
    return os.path.normcase(
        os.path.abspath(getsourcefile(object) or getfile(object)))

# Cache mapping absolute file names to module names, filled lazily by
# getmodule() from sys.modules.
modulesbyfile = {}

def getmodule(object):
    """Return the module an object was defined in, or None if not found."""
    if ismodule(object):
        return object
    if isclass(object):
        return sys.modules.get(object.__module__)
    try:
        file = getabsfile(object)
    except TypeError:
        return None
    if modulesbyfile.has_key(file):
        return sys.modules[modulesbyfile[file]]
    for module in sys.modules.values():
        if hasattr(module, '__file__'):
            modulesbyfile[getabsfile(module)] = module.__name__
    if modulesbyfile.has_key(file):
        return sys.modules[modulesbyfile[file]]
    # Fall back to identity checks against __main__ and __builtin__.
    main = sys.modules['__main__']
    if hasattr(main, object.__name__):
        mainobject = getattr(main, object.__name__)
        if mainobject is object:
            return main
    builtin = sys.modules['__builtin__']
    if hasattr(builtin, object.__name__):
        builtinobject = getattr(builtin, object.__name__)
        if builtinobject is object:
            return builtin

def findsource(object):
    """Return the entire source file and starting line number for an object.

    The argument may be a module, class, method, function, traceback, frame,
    or code object.  The source code is returned as a list of all the lines
    in the file and the line number indexes a line in that list.  An IOError
    is raised if the source code cannot be retrieved."""
    try:
        file = open(getsourcefile(object))
    except (TypeError, IOError):
        raise IOError, 'could not get source code'
    lines = file.readlines()
    file.close()

    if ismodule(object):
        return lines, 0

    if isclass(object):
        name = object.__name__
        pat = re.compile(r'^\s*class\s*' + name + r'\b')
        for i in range(len(lines)):
            if pat.match(lines[i]): return lines, i
        else: raise IOError, 'could not find class definition'

    if ismethod(object):
        object = object.im_func
    if isfunction(object):
        object = object.func_code
    if istraceback(object):
        object = object.tb_frame
    if isframe(object):
        object = object.f_code
    if iscode(object):
        if not hasattr(object, 'co_firstlineno'):
            raise IOError, 'could not find function definition'
        # Scan backwards from co_firstlineno for the def/lambda line.
        lnum = object.co_firstlineno - 1
        pat = re.compile(r'^(\s*def\s)|(.*\slambda(:|\s))')
        while lnum > 0:
            if pat.match(lines[lnum]): break
            lnum = lnum - 1
        return lines, lnum
    raise IOError, 'could not find code object'

def getcomments(object):
    """Get lines of comments immediately preceding an object's source code."""
    try: lines, lnum = findsource(object)
    except IOError: return None

    if ismodule(object):
        # Look for a comment block at the top of the file.
        start = 0
        if lines and lines[0][:2] == '#!': start = 1
        while start < len(lines) and string.strip(lines[start]) in ['', '#']:
            start = start + 1
        if start < len(lines) and lines[start][:1] == '#':
            comments = []
            end = start
            while end < len(lines) and lines[end][:1] == '#':
                comments.append(string.expandtabs(lines[end]))
                end = end + 1
            return string.join(comments, '')

    # Look for a preceding block of comments at the same indentation.
    elif lnum > 0:
        indent = indentsize(lines[lnum])
        end = lnum - 1
        if end >= 0 and string.lstrip(lines[end])[:1] == '#' and \
            indentsize(lines[end]) == indent:
            comments = [string.lstrip(string.expandtabs(lines[end]))]
            if end > 0:
                end = end - 1
                comment = string.lstrip(string.expandtabs(lines[end]))
                while comment[:1] == '#' and indentsize(lines[end]) == indent:
                    comments[:0] = [comment]
                    end = end - 1
                    if end < 0: break
                    comment = string.lstrip(string.expandtabs(lines[end]))
            # Trim leading and trailing lines that are only '#'.
            while comments and string.strip(comments[0]) == '#':
                comments[:1] = []
            while comments and string.strip(comments[-1]) == '#':
                comments[-1:] = []
            return string.join(comments, '')

class ListReader:
    """Provide a readline() method to return lines from a list of strings."""
    def __init__(self, lines):
        self.lines = lines
        self.index = 0

    def readline(self):
        i = self.index
        if i < len(self.lines):
            self.index = i + 1
            return self.lines[i]
        else: return ''

class EndOfBlock(Exception): pass

class BlockFinder:
    """Provide a tokeneater() method to detect the end of a code block."""
    def __init__(self):
        self.indent = 0
        self.started = 0
        self.last = 0

    # NOTE: tuple parameters -- Python 2 only syntax.
    def tokeneater(self, type, token, (srow, scol), (erow, ecol), line):
        if not self.started:
            if type == tokenize.NAME: self.started = 1
        elif type == tokenize.NEWLINE:
            self.last = srow
        elif type == tokenize.INDENT:
            self.indent = self.indent + 1
        elif type == tokenize.DEDENT:
            self.indent = self.indent - 1
            if self.indent == 0: raise EndOfBlock, self.last
        elif type == tokenize.NAME and scol == 0:
            raise EndOfBlock, self.last

def getblock(lines):
    """Extract the block of code at the top of the given list of lines."""
    try:
        tokenize.tokenize(ListReader(lines).readline, BlockFinder().tokeneater)
    except EndOfBlock, eob:
        return lines[:eob.args[0]]
    # Fooling the indent/dedent logic implies a one-line definition
    return lines[:1]

def getsourcelines(object):
    """Return a list of source lines and starting line number for an object.

    The argument may be a module, class, method, function, traceback, frame,
    or code object.  The source code is returned as a list of the lines
    corresponding to the object and the line number indicates where in the
    original source file the first line of code was found.  An IOError is
    raised if the source code cannot be retrieved."""
    lines, lnum = findsource(object)

    if ismodule(object): return lines, 0
    else: return getblock(lines[lnum:]), lnum + 1

def getsource(object):
    """Return the text of the source code for an object.

    The argument may be a module, class, method, function, traceback, frame,
    or code object.  The source code is returned as a single string.  An
    IOError is raised if the source code cannot be retrieved."""
    lines, lnum = getsourcelines(object)
    return string.join(lines, '')

# --------------------------------------------------- class tree extraction
def walktree(classes, children, parent):
    """Recursive helper function for getclasstree()."""
    results = []
    classes.sort(lambda a, b: cmp(a.__name__, b.__name__))
    for c in classes:
        results.append((c, c.__bases__))
        if children.has_key(c):
            results.append(walktree(children[c], children, c))
    return results

def getclasstree(classes, unique=0):
    """Arrange the given list of classes into a hierarchy of nested lists.

    Where a nested list appears, it contains classes derived from the class
    whose entry immediately precedes the list.  Each entry is a 2-tuple
    containing a class and a tuple of its base classes.  If the 'unique'
    argument is true, exactly one entry appears in the returned structure
    for each class in the given list.  Otherwise, classes using multiple
    inheritance and their descendants will appear multiple times."""
    children = {}
    roots = []
    for c in classes:
        if c.__bases__:
            for parent in c.__bases__:
                if not children.has_key(parent):
                    children[parent] = []
                children[parent].append(c)
                if unique and parent in classes: break
        elif c not in roots:
            roots.append(c)
    for parent in children.keys():
        if parent not in classes:
            roots.append(parent)
    return walktree(roots, children, None)

# ------------------------------------------------ argument list extraction
# These constants are from Python's compile.h.
CO_OPTIMIZED, CO_NEWLOCALS, CO_VARARGS, CO_VARKEYWORDS = 1, 2, 4, 8

def getargs(co):
    """Get information about the arguments accepted by a code object.

    Three things are returned: (args, varargs, varkw), where 'args' is
    a list of argument names (possibly containing nested lists), and
    'varargs' and 'varkw' are the names of the * and ** arguments or None."""
    if not iscode(co): raise TypeError, 'arg is not a code object'

    code = co.co_code
    nargs = co.co_argcount
    names = co.co_varnames
    args = list(names[:nargs])
    step = 0

    # The following acrobatics are for anonymous (tuple) arguments.
    # It decodes the UNPACK_*/STORE_FAST bytecode prologue that CPython
    # emits for 'def f(a, (b, c))'-style signatures.
    for i in range(nargs):
        if args[i][:1] in ['', '.']:
            stack, remain, count = [], [], []
            while step < len(code):
                op = ord(code[step])
                step = step + 1
                if op >= dis.HAVE_ARGUMENT:
                    opname = dis.opname[op]
                    value = ord(code[step]) + ord(code[step+1])*256
                    step = step + 2
                    if opname in ['UNPACK_TUPLE', 'UNPACK_SEQUENCE']:
                        remain.append(value)
                        count.append(value)
                    elif opname == 'STORE_FAST':
                        stack.append(names[value])
                        remain[-1] = remain[-1] - 1
                        while remain[-1] == 0:
                            remain.pop()
                            size = count.pop()
                            stack[-size:] = [stack[-size:]]
                            if not remain: break
                            remain[-1] = remain[-1] - 1
                        if not remain: break
            args[i] = stack[0]

    varargs = None
    if co.co_flags & CO_VARARGS:
        varargs = co.co_varnames[nargs]
        nargs = nargs + 1
    varkw = None
    if co.co_flags & CO_VARKEYWORDS:
        varkw = co.co_varnames[nargs]
    return args, varargs, varkw

def getargspec(func):
    """Get the names and default values of a function's arguments.

    A tuple of four things is returned: (args, varargs, varkw, defaults).
    'args' is a list of the argument names (it may contain nested lists).
    'varargs' and 'varkw' are the names of the * and ** arguments or None.
    'defaults' is an n-tuple of the default values of the last n arguments."""
    if ismethod(func):
        func = func.im_func
    if not isfunction(func): raise TypeError, 'arg is not a Python function'
    args, varargs, varkw = getargs(func.func_code)
    return args, varargs, varkw, func.func_defaults

def getargvalues(frame):
    """Get information about arguments passed into a particular frame.

    A tuple of four things is returned: (args, varargs, varkw, locals).
    'args' is a list of the argument names (it may contain nested lists).
    'varargs' and 'varkw' are the names of the * and ** arguments or None.
    'locals' is the locals dictionary of the given frame."""
    args, varargs, varkw = getargs(frame.f_code)
    return args, varargs, varkw, frame.f_locals

def joinseq(seq):
    # Render a sequence of name strings as a tuple display, with the
    # one-element trailing-comma special case.
    if len(seq) == 1:
        return '(' + seq[0] + ',)'
    else:
        return '(' + string.join(seq, ', ') + ')'

def strseq(object, convert, join=joinseq):
    """Recursively walk a sequence, stringifying each element."""
    if type(object) in [types.ListType, types.TupleType]:
        return join(map(lambda o, c=convert, j=join: strseq(o, c, j), object))
    else:
        return convert(object)

def formatargspec(args, varargs=None, varkw=None, defaults=None,
                  formatarg=str,
                  formatvarargs=lambda name: '*' + name,
                  formatvarkw=lambda name: '**' + name,
                  formatvalue=lambda value: '=' + repr(value),
                  join=joinseq):
    """Format an argument spec from the 4 values returned by getargspec.

    The first four arguments are (args, varargs, varkw, defaults).  The
    other four arguments are the corresponding optional formatting functions
    that are called to turn names and values into strings.  The ninth
    argument is an optional function to format the sequence of arguments."""
    specs = []
    if defaults:
        firstdefault = len(args) - len(defaults)
    for i in range(len(args)):
        spec = strseq(args[i], formatarg, join)
        if defaults and i >= firstdefault:
            spec = spec + formatvalue(defaults[i - firstdefault])
        specs.append(spec)
    if varargs:
        specs.append(formatvarargs(varargs))
    if varkw:
        specs.append(formatvarkw(varkw))
    return '(' + string.join(specs, ', ') + ')'

def formatargvalues(args, varargs, varkw, locals,
                    formatarg=str,
                    formatvarargs=lambda name: '*' + name,
                    formatvarkw=lambda name: '**' + name,
                    formatvalue=lambda value: '=' + repr(value),
                    join=joinseq):
    """Format an argument spec from the 4 values returned by getargvalues.

    The first four arguments are (args, varargs, varkw, locals).  The
    next four arguments are the corresponding optional formatting functions
    that are called to turn names and values into strings.  The ninth
    argument is an optional function to format the sequence of arguments."""
    def convert(name, locals=locals,
                formatarg=formatarg, formatvalue=formatvalue):
        return formatarg(name) + formatvalue(locals[name])
    specs = []
    for i in range(len(args)):
        specs.append(strseq(args[i], convert, join))
    if varargs:
        specs.append(formatvarargs(varargs) + formatvalue(locals[varargs]))
    if varkw:
        specs.append(formatvarkw(varkw) + formatvalue(locals[varkw]))
    return '(' + string.join(specs, ', ') + ')'

# -------------------------------------------------- stack frame extraction
def getframeinfo(frame, context=1):
    """Get information about a frame or traceback object.

    A tuple of five things is returned: the filename, the line number of
    the current line, the function name, a list of lines of context from
    the source code, and the index of the current line within that list.
    The optional second argument specifies the number of lines of context
    to return, which are centered around the current line."""
    if istraceback(frame):
        frame = frame.tb_frame
    if not isframe(frame):
        raise TypeError, 'arg is not a frame or traceback object'

    filename = getsourcefile(frame)
    lineno = getlineno(frame)
    if context > 0:
        start = lineno - 1 - context//2
        try:
            lines, lnum = findsource(frame)
        except IOError:
            lines = index = None
        else:
            # NOTE(review): clamps to 1 rather than 0, so the file's first
            # line can never appear as context -- confirm this is intended
            # (it matches the code as committed).
            start = max(start, 1)
            start = min(start, len(lines) - context)
            lines = lines[start:start+context]
            index = lineno - 1 - start
    else:
        lines = index = None

    return (filename, lineno, frame.f_code.co_name, lines, index)

def getlineno(frame):
    """Get the line number from a frame object, allowing for optimization."""
    # Written by Marc-André Lemburg; revised by Jim Hugunin and Fredrik Lundh.
    lineno = frame.f_lineno
    code = frame.f_code
    if hasattr(code, 'co_lnotab'):
        # Walk the (addr delta, line delta) byte pairs of co_lnotab up to
        # the frame's last instruction to recover the true line number.
        table = code.co_lnotab
        lineno = code.co_firstlineno
        addr = 0
        for i in range(0, len(table), 2):
            addr = addr + ord(table[i])
            if addr > frame.f_lasti: break
            lineno = lineno + ord(table[i+1])
    return lineno

def getouterframes(frame, context=1):
    """Get a list of records for a frame and all higher (calling) frames.

    Each record contains a frame object, filename, line number, function
    name, a list of lines of context, and index within the context."""
    framelist = []
    while frame:
        framelist.append((frame,) + getframeinfo(frame, context))
        frame = frame.f_back
    return framelist

def getinnerframes(tb, context=1):
    """Get a list of records for a traceback's frame and all lower frames.

    Each record contains a frame object, filename, line number, function
    name, a list of lines of context, and index within the context."""
    framelist = []
    while tb:
        framelist.append((tb.tb_frame,) + getframeinfo(tb, context))
        tb = tb.tb_next
    return framelist

def currentframe():
    """Return the frame object for the caller's stack frame."""
    try:
        # Raising and catching an exception is the portable way to reach
        # the caller's frame when sys._getframe is unavailable (Python 2
        # string exception -- left as-is for this vintage of Jython).
        raise 'catch me'
    except:
        return sys.exc_traceback.tb_frame.f_back

if hasattr(sys, '_getframe'): currentframe = sys._getframe

def stack(context=1):
    """Return a list of records for the stack above the caller's frame."""
    return getouterframes(currentframe().f_back, context)

def trace(context=1):
    """Return a list of records for the stack below the current exception."""
    return getinnerframes(sys.exc_traceback, context)

# --- mailing-list residue below: start of an unrelated unified diff of
# --- Lib/test/test_inspect.py carried in the same archive chunk; it is
# --- not part of inspect.py itself and is reproduced unchanged.

Modified: trunk/jython/Lib/test/test_inspect.py
===================================================================
--- trunk/jython/Lib/test/test_inspect.py	2007-04-29 08:59:39 UTC (rev 3198)
+++ trunk/jython/Lib/test/test_inspect.py	2007-04-30 06:45:58 UTC (rev 3199)
@@ -1,6 +1,3 @@
-#FIXME: removed all of the tests that don't really apply to Jython's version of
However, some of the missing functionality in inspect -# is implementable -- so check back here if that is done. source = '''# line 1 'A module docstring.' @@ -46,7 +43,7 @@ spam(a, b, c) except: self.ex = sys.exc_info() - #self.tr = inspect.trace() + self.tr = inspect.trace() # line 48 class MalodorousPervert(StupidGit): @@ -85,15 +82,9 @@ def istest(func, exp): obj = eval(exp) test(func(obj), '%s(%s)' % (func.__name__, exp)) - for other in [#inspect.isbuiltin, - inspect.isclass, - #inspect.iscode, - #inspect.isframe, - inspect.isfunction, - inspect.ismethod, - inspect.ismodule, - #inspect.istraceback - ]: + for other in [inspect.isbuiltin, inspect.isclass, inspect.iscode, + inspect.isframe, inspect.isfunction, inspect.ismethod, + inspect.ismodule, inspect.istraceback]: if other is not func: test(not other(obj), 'not %s(%s)' % (other.__name__, exp)) @@ -103,119 +94,119 @@ except: tb = sys.exc_traceback -#istest(inspect.isbuiltin, 'sys.exit') -#istest(inspect.isbuiltin, '[].append') +istest(inspect.isbuiltin, 'sys.exit') +istest(inspect.isbuiltin, '[].append') istest(inspect.isclass, 'mod.StupidGit') -#istest(inspect.iscode, 'mod.spam.func_code') -#istest(inspect.isframe, 'tb.tb_frame') +istest(inspect.iscode, 'mod.spam.func_code') +istest(inspect.isframe, 'tb.tb_frame') istest(inspect.isfunction, 'mod.spam') istest(inspect.ismethod, 'mod.StupidGit.abuse') istest(inspect.ismethod, 'git.argue') istest(inspect.ismodule, 'mod') -#istest(inspect.istraceback, 'tb') -#test(inspect.isroutine(mod.spam), 'isroutine(mod.spam)') -#test(inspect.isroutine([].count), 'isroutine([].count)') +istest(inspect.istraceback, 'tb') +test(inspect.isroutine(mod.spam), 'isroutine(mod.spam)') +test(inspect.isroutine([].count), 'isroutine([].count)') -#classes = inspect.getmembers(mod, inspect.isclass) -#test(classes == -# [('FesteringGob', mod.FesteringGob), -# ('MalodorousPervert', mod.MalodorousPervert), -# ('ParrotDroppings', mod.ParrotDroppings), -# ('StupidGit', mod.StupidGit)], 'class 
list') -#tree = inspect.getclasstree(map(lambda x: x[1], classes), 1) -#test(tree == -# [(mod.ParrotDroppings, ()), -# (mod.StupidGit, ()), -# [(mod.MalodorousPervert, (mod.StupidGit,)), -# [(mod.FesteringGob, (mod.MalodorousPervert, mod.ParrotDroppings)) -# ] -# ] -# ], 'class tree') +classes = inspect.getmembers(mod, inspect.isclass) +test(classes == + [('FesteringGob', mod.FesteringGob), + ('MalodorousPervert', mod.MalodorousPervert), + ('ParrotDroppings', mod.ParrotDroppings), + ('StupidGit', mod.StupidGit)], 'class list') +tree = inspect.getclasstree(map(lambda x: x[1], classes), 1) +test(tree == + [(mod.ParrotDroppings, ()), + (mod.StupidGit, ()), + [(mod.MalodorousPervert, (mod.StupidGit,)), + [(mod.FesteringGob, (mod.MalodorousPervert, mod.ParrotDroppings)) + ] + ] + ], 'class tree') -#functions = inspect.getmembers(mod, inspect.isfunction) -#test(functions == [('eggs', mod.eggs), ('spam', mod.spam)], 'function list') +functions = inspect.getmembers(mod, inspect.isfunction) +test(functions == [('eggs', mod.eggs), ('spam', mod.spam)], 'function list') -#test(inspect.getdoc(mod) == 'A module docstring.', 'getdoc(mod)') -#test(inspect.getcomments(mod) == '# line 1\n', 'getcomments(mod)') -#test(inspect.getmodule(mod.StupidGit) == mod, 'getmodule(mod.StupidGit)') -#test(inspect.getfile(mod.StupidGit) == TESTFN, 'getfile(mod.StupidGit)') -#test(inspect.getsourcefile(mod.spam) == TESTFN, 'getsourcefile(mod.spam)') -#test(inspect.getsourcefile(git.abuse) == TESTFN, 'getsourcefile(git.abuse)') +test(inspect.getdoc(mod) == 'A module docstring.', 'getdoc(mod)') +test(inspect.getcomments(mod) == '# line 1\n', 'getcomments(mod)') +test(inspect.getmodule(mod.StupidGit) == mod, 'getmodule(mod.StupidGit)') +test(inspect.getfile(mod.StupidGit) == TESTFN, 'getfile(mod.StupidGit)') +test(inspect.getsourcefile(mod.spam) == TESTFN, 'getsourcefile(mod.spam)') +test(inspect.getsourcefile(git.abuse) == TESTFN, 'getsourcefile(git.abuse)') def sourcerange(top, bottom): lines = 
string.split(source, '\n') return string.join(lines[top-1:bottom], '\n') + '\n' -#test(inspect.getsource(git.abuse) == sourcerange(29, 39), -# 'getsource(git.abuse)') -#test(inspect.getsource(mod.StupidGit) == sourcerange(21, 46), -# 'getsource(mod.StupidGit)') -#test(inspect.getdoc(mod.StupidGit) == -# 'A longer,\n\nindented\n\ndocstring.', 'getdoc(mod.StupidGit)') -#test(inspect.getdoc(git.abuse) == -# 'Another\n\ndocstring\n\ncontaining\n\ntabs\n\n', 'getdoc(git.abuse)') -#test(inspect.getcomments(mod.StupidGit) == '# line 20\n', -# 'getcomments(mod.StupidGit)') +test(inspect.getsource(git.abuse) == sourcerange(29, 39), + 'getsource(git.abuse)') +test(inspect.getsource(mod.StupidGit) == sourcerange(21, 46), + 'getsource(mod.StupidGit)') +test(inspect.getdoc(mod.StupidGit) == + 'A longer,\n\nindented\n\ndocstring.', 'getdoc(mod.StupidGit)') +test(inspect.getdoc(git.abuse) == + 'Another\n\ndocstring\n\ncontaining\n\ntabs\n\n', 'getdoc(git.abuse)') +test(inspect.getcomments(mod.StupidGit) == '# line 20\n', + 'getcomments(mod.StupidGit)') -#args, varargs, varkw, defaults = inspect.getargspec(mod.eggs) -#test(args == ['x', 'y'], 'mod.eggs args') -#test(varargs == None, 'mod.eggs varargs') -#test(varkw == None, 'mod.eggs varkw') -#test(defaults == None, 'mod.eggs defaults') -#test(inspect.formatargspec(args, varargs, varkw, defaults) == -# '(x, y)', 'mod.eggs formatted argspec') -#args, varargs, varkw, defaults = inspect.getargspec(mod.spam) -#test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.spam args') -#test(varargs == 'g', 'mod.spam varargs') -#test(varkw == 'h', 'mod.spam varkw') -#test(defaults == (3, (4, (5,))), 'mod.spam defaults') -#test(inspect.formatargspec(args, varargs, varkw, defaults) == -# '(a, b, c, d=3, (e, (f,))=(4, (5,)), *g, **h)', -# 'mod.spam formatted argspec') +args, varargs, varkw, defaults = inspect.getargspec(mod.eggs) +test(args == ['x', 'y'], 'mod.eggs args') +test(varargs == None, 'mod.eggs varargs') +test(varkw == None, 'mod.eggs 
varkw') +test(defaults == None, 'mod.eggs defaults') +test(inspect.formatargspec(args, varargs, varkw, defaults) == + '(x, y)', 'mod.eggs formatted argspec') +args, varargs, varkw, defaults = inspect.getargspec(mod.spam) +test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.spam args') +test(varargs == 'g', 'mod.spam varargs') +test(varkw == 'h', 'mod.spam varkw') +test(defaults == (3, (4, (5,))), 'mod.spam defaults') +test(inspect.formatargspec(args, varargs, varkw, defaults) == + '(a, b, c, d=3, (e, (f,))=(4, (5,)), *g, **h)', + 'mod.spam formatted argspec') git.abuse(7, 8, 9) -#istest(inspect.istraceback, 'git.ex[2]') -#istest(inspect.isframe, 'mod.fr') +istest(inspect.istraceback, 'git.ex[2]') +istest(inspect.isframe, 'mod.fr') -#test(len(git.tr) == 3, 'trace() length') -#test(git.tr[0][1:] == (TESTFN, 46, 'argue', -# [' self.tr = inspect.trace()\n'], 0), -# 'trace() row 2') -#test(git.tr[1][1:] == (TESTFN, 9, 'spam', [' eggs(b + d, c + f)\n'], 0), -# 'trace() row 2') -#test(git.tr[2][1:] == (TESTFN, 18, 'eggs', [' q = y / 0\n'], 0), -# 'trace() row 3') +test(len(git.tr) == 3, 'trace() length') +test(git.tr[0][1:] == (TESTFN, 46, 'argue', + [' self.tr = inspect.trace()\n'], 0), + 'trace() row 2') +test(git.tr[1][1:] == (TESTFN, 9, 'spam', [' eggs(b + d, c + f)\n'], 0), + 'trace() row 2') +test(git.tr[2][1:] == (TESTFN, 18, 'eggs', [' q = y / 0\n'], 0), + 'trace() row 3') -#test(len(mod.st) >= 5, 'stack() length') -#test(mod.st[0][1:] == -# (TESTFN, 16, 'eggs', [' st = inspect.stack()\n'], 0), -# 'stack() row 1') -#test(mod.st[1][1:] == -# (TESTFN, 9, 'spam', [' eggs(b + d, c + f)\n'], 0), -# 'stack() row 2') -#test(mod.st[2][1:] == -# (TESTFN, 43, 'argue', [' spam(a, b, c)\n'], 0), -# 'stack() row 3') -#test(mod.st[3][1:] == -# (TESTFN, 39, 'abuse', [' self.argue(a, b, c)\n'], 0), -# 'stack() row 4') +test(len(mod.st) >= 5, 'stack() length') +test(mod.st[0][1:] == + (TESTFN, 16, 'eggs', [' st = inspect.stack()\n'], 0), + 'stack() row 1') +test(mod.st[1][1:] == 
+ (TESTFN, 9, 'spam', [' eggs(b + d, c + f)\n'], 0), + 'stack() row 2') +test(mod.st[2][1:] == + (TESTFN, 43, 'argue', [' spam(a, b, c)\n'], 0), + 'stack() row 3') +test(mod.st[3][1:] == + (TESTFN, 39, 'abuse', [' self.argue(a, b, c)\n'], 0), + 'stack() row 4') -#args, varargs, varkw, locals = inspect.getargvalues(mod.fr) -#test(args == ['x', 'y'], 'mod.fr args') -#test(varargs == None, 'mod.fr varargs') -#test(varkw == None, 'mod.fr varkw') -#test(locals == {'x': 11, 'p': 11, 'y': 14}, 'mod.fr locals') -#test(inspect.formatargvalues(args, varargs, varkw, locals) == -# '(x=11, y=14)', 'mod.fr formatted argvalues') +args, varargs, varkw, locals = inspect.getargvalues(mod.fr) +test(args == ['x', 'y'], 'mod.fr args') +test(varargs == None, 'mod.fr varargs') +test(varkw == None, 'mod.fr varkw') +test(locals == {'x': 11, 'p': 11, 'y': 14}, 'mod.fr locals') +test(inspect.formatargvalues(args, varargs, varkw, locals) == + '(x=11, y=14)', 'mod.fr formatted argvalues') -#args, varargs, varkw, locals = inspect.getargvalues(mod.fr.f_back) -#test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.fr.f_back args') -#test(varargs == 'g', 'mod.fr.f_back varargs') -#test(varkw == 'h', 'mod.fr.f_back varkw') -#test(inspect.formatargvalues(args, varargs, varkw, locals) == -# '(a=7, b=8, c=9, d=3, (e=4, (f=5,)), *g=(), **h={})', -# 'mod.fr.f_back formatted argvalues') +args, varargs, varkw, locals = inspect.getargvalues(mod.fr.f_back) +test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.fr.f_back args') +test(varargs == 'g', 'mod.fr.f_back varargs') +test(varkw == 'h', 'mod.fr.f_back varkw') +test(inspect.formatargvalues(args, varargs, varkw, locals) == + '(a=7, b=8, c=9, d=3, (e=4, (f=5,)), *g=(), **h={})', + 'mod.fr.f_back formatted argvalues') for fname in files_to_clean_up: try: @@ -230,8 +221,8 @@ class D(B, C): pass expected = (D, B, A, C) -#got = inspect.getmro(D) -#test(expected == got, "expected %r mro, got %r", expected, got) +got = inspect.getmro(D) +test(expected == got, 
"expected %r mro, got %r", expected, got) # The same w/ new-class MRO. class A(object): pass @@ -240,8 +231,8 @@ class D(B, C): pass expected = (D, B, C, A, object) -#got = inspect.getmro(D) -#test(expected == got, "expected %r mro, got %r", expected, got) +got = inspect.getmro(D) +test(expected == got, "expected %r mro, got %r", expected, got) # Test classify_class_attrs. def attrs_wo_objs(cls): @@ -263,110 +254,110 @@ datablob = '1' -#attrs = attrs_wo_objs(A) -#test(('s', 'static method', A) in attrs, 'missing static method') -#test(('c', 'class method', A) in attrs, 'missing class method') -#test(('p', 'property', A) in attrs, 'missing property') -#test(('m', 'method', A) in attrs, 'missing plain method') -#test(('m1', 'method', A) in attrs, 'missing plain method') -#test(('datablob', 'data', A) in attrs, 'missing data') +attrs = attrs_wo_objs(A) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'class method', A) in attrs, 'missing class method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', A) in attrs, 'missing plain method') +test(('m1', 'method', A) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') -#class B(A): -# def m(self): pass +class B(A): + def m(self): pass -#attrs = attrs_wo_objs(B) -#test(('s', 'static method', A) in attrs, 'missing static method') -#test(('c', 'class method', A) in attrs, 'missing class method') -#test(('p', 'property', A) in attrs, 'missing property') -#test(('m', 'method', B) in attrs, 'missing plain method') -#test(('m1', 'method', A) in attrs, 'missing plain method') -#test(('datablob', 'data', A) in attrs, 'missing data') +attrs = attrs_wo_objs(B) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'class method', A) in attrs, 'missing class method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', B) in attrs, 'missing plain method') +test(('m1', 'method', A) in 
attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') -#class C(A): -# def m(self): pass -# def c(self): pass +class C(A): + def m(self): pass + def c(self): pass -#attrs = attrs_wo_objs(C) -#test(('s', 'static method', A) in attrs, 'missing static method') -#test(('c', 'method', C) in attrs, 'missing plain method') -#test(('p', 'property', A) in attrs, 'missing property') -#test(('m', 'method', C) in attrs, 'missing plain method') -#test(('m1', 'method', A) in attrs, 'missing plain method') -#test(('datablob', 'data', A) in attrs, 'missing data') +attrs = attrs_wo_objs(C) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'method', C) in attrs, 'missing plain method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', C) in attrs, 'missing plain method') +test(('m1', 'method', A) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') -#class D(B, C): -# def m1(self): pass +class D(B, C): + def m1(self): pass -#attrs = attrs_wo_objs(D) -#test(('s', 'static method', A) in attrs, 'missing static method') -#test(('c', 'class method', A) in attrs, 'missing class method') -#test(('p', 'property', A) in attrs, 'missing property') -#test(('m', 'method', B) in attrs, 'missing plain method') -#test(('m1', 'method', D) in attrs, 'missing plain method') -#test(('datablob', 'data', A) in attrs, 'missing data') +attrs = attrs_wo_objs(D) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'class method', A) in attrs, 'missing class method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', B) in attrs, 'missing plain method') +test(('m1', 'method', D) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') # Repeat all that, but w/ new-style classes. 
-#class A(object): +class A(object): -# def s(): pass -# s = staticmethod(s) + def s(): pass + s = staticmethod(s) -# def c(cls): pass -# c = classmethod(c) + def c(cls): pass + c = classmethod(c) -# def getp(self): pass -# p = property(getp) + def getp(self): pass + p = property(getp) -# def m(self): pass + def m(self): pass -# def m1(self): pass + def m1(self): pass -# datablob = '1' + datablob = '1' -#attrs = attrs_wo_objs(A) -#test(('s', 'static method', A) in attrs, 'missing static method') -#test(('c', 'class method', A) in attrs, 'missing class method') -#test(('p', 'property', A) in attrs, 'missing property') -#test(('m', 'method', A) in attrs, 'missing plain method') -#test(('m1', 'method', A) in attrs, 'missing plain method') -#test(('datablob', 'data', A) in attrs, 'missing data') +attrs = attrs_wo_objs(A) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'class method', A) in attrs, 'missing class method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', A) in attrs, 'missing plain method') +test(('m1', 'method', A) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') -#class B(A): +class B(A): -# def m(self): pass + def m(self): pass -#attrs = attrs_wo_objs(B) -#test(('s', 'static method', A) in attrs, 'missing static method') -#test(('c', 'class method', A) in attrs, 'missing class method') -#test(('p', 'property', A) in attrs, 'missing property') -#test(('m', 'method', B) in attrs, 'missing plain method') -#test(('m1', 'method', A) in attrs, 'missing plain method') -#test(('datablob', 'data', A) in attrs, 'missing data') +attrs = attrs_wo_objs(B) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'class method', A) in attrs, 'missing class method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', B) in attrs, 'missing plain method') +test(('m1', 'method', A) in attrs, 'missing plain method') 
+test(('datablob', 'data', A) in attrs, 'missing data') -#class C(A): +class C(A): -# def m(self): pass -# def c(self): pass + def m(self): pass + def c(self): pass -#attrs = attrs_wo_objs(C) -#test(('s', 'static method', A) in attrs, 'missing static method') -#test(('c', 'method', C) in attrs, 'missing plain method') -#test(('p', 'property', A) in attrs, 'missing property') -#test(('m', 'method', C) in attrs, 'missing plain method') -#test(('m1', 'method', A) in attrs, 'missing plain method') -#test(('datablob', 'data', A) in attrs, 'missing data') +attrs = attrs_wo_objs(C) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'method', C) in attrs, 'missing plain method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', C) in attrs, 'missing plain method') +test(('m1', 'method', A) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') -#class D(B, C): +class D(B, C): -# def m1(self): pass + def m1(self): pass -#attrs = attrs_wo_objs(D) -#test(('s', 'static method', A) in attrs, 'missing static method') -#test(('c', 'method', C) in attrs, 'missing plain method') -#test(('p', 'property', A) in attrs, 'missing property') -#test(('m', 'method', B) in attrs, 'missing plain method') -#test(('m1', 'method', D) in attrs, 'missing plain method') -#test(('datablob', 'data', A) in attrs, 'missing data') +attrs = attrs_wo_objs(D) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'method', C) in attrs, 'missing plain method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', B) in attrs, 'missing plain method') +test(('m1', 'method', D) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <cg...@us...> - 2007-05-20 06:36:56
|
Revision: 3236 http://svn.sourceforge.net/jython/?rev=3236&view=rev Author: cgroves Date: 2007-05-19 23:36:55 -0700 (Sat, 19 May 2007) Log Message: ----------- Couple python.cachedir.skip=true fixes Modified Paths: -------------- trunk/jython/Lib/test/test_jbasic.py trunk/jython/Lib/zlib.py Modified: trunk/jython/Lib/test/test_jbasic.py =================================================================== --- trunk/jython/Lib/test/test_jbasic.py 2007-05-20 04:55:01 UTC (rev 3235) +++ trunk/jython/Lib/test/test_jbasic.py 2007-05-20 06:36:55 UTC (rev 3236) @@ -60,10 +60,10 @@ assert s.regionMatches(1, 1, 'eLl', 0, 3), 'method call ignore case' assert not s.regionMatches(1, 'eLl', 0, 3), 'should ignore case' -from java import awt +from java.awt import Dimension print_test('get/set fields') -d = awt.Dimension(3,9) +d = Dimension(3,9) assert d.width == 3 and d.height == 9, 'getting fields' d.width = 42 assert d.width == 42 and d.height == 9, 'setting fields' @@ -91,8 +91,10 @@ global flag flag = flag + 1 -doit = awt.event.ActionEvent(b1, awt.event.ActionEvent.ACTION_PERFORMED, "") +from java.awt.event import ActionEvent +doit = ActionEvent(b1, ActionEvent.ACTION_PERFORMED, "") + b1.actionPerformed = testAction flag = 0 b1.doClick() Modified: trunk/jython/Lib/zlib.py =================================================================== --- trunk/jython/Lib/zlib.py 2007-05-20 04:55:01 UTC (rev 3235) +++ trunk/jython/Lib/zlib.py 2007-05-20 06:36:55 UTC (rev 3236) @@ -1,7 +1,8 @@ - -from java import util, lang import jarray, binascii +from java.util.zip import Adler32, Deflater, Inflater +from java.lang import Long, String, StringBuffer + class error(Exception): pass @@ -29,9 +30,9 @@ def adler32(string, value=1): if value != 1: raise ValueError, "adler32 only support start value of 1" - checksum = util.zip.Adler32() - checksum.update(lang.String.getBytes(string)) - return lang.Long(checksum.getValue()).intValue() + checksum = Adler32() + 
checksum.update(String.getBytes(string)) + return Long(checksum.getValue()).intValue() def crc32(string, value=0): return binascii.crc32(string, value) @@ -40,13 +41,13 @@ def compress(string, level=6): if level < Z_BEST_SPEED or level > Z_BEST_COMPRESSION: raise error, "Bad compression level" - deflater = util.zip.Deflater(level, 0) + deflater = Deflater(level, 0) deflater.setInput(string, 0, len(string)) deflater.finish() return _get_deflate_data(deflater) def decompress(string, wbits=0, bufsize=16384): - inflater = util.zip.Inflater(wbits < 0) + inflater = Inflater(wbits < 0) inflater.setInput(string) return _get_inflate_data(inflater) @@ -57,7 +58,7 @@ memLevel=0, strategy=0): if abs(wbits) > MAX_WBITS or abs(wbits) < 8: raise ValueError, "Invalid initialization option" - self.deflater = util.zip.Deflater(level, wbits < 0) + self.deflater = Deflater(level, wbits < 0) self.deflater.setStrategy(strategy) if wbits < 0: _get_deflate_data(self.deflater) @@ -86,7 +87,7 @@ def __init__(self, wbits=MAX_WBITS): if abs(wbits) > MAX_WBITS or abs(wbits) < 8: raise ValueError, "Invalid initialization option" - self.inflater = util.zip.Inflater(wbits < 0) + self.inflater = Inflater(wbits < 0) self.unused_data = "" self._ended = False @@ -127,18 +128,18 @@ def _get_deflate_data(deflater): buf = jarray.zeros(1024, 'b') - sb = lang.StringBuffer() + sb = StringBuffer() while not deflater.finished(): l = deflater.deflate(buf) if l == 0: break - sb.append(lang.String(buf, 0, 0, l)) + sb.append(String(buf, 0, 0, l)) return sb.toString() def _get_inflate_data(inflater, max_length=0): buf = jarray.zeros(1024, 'b') - sb = lang.StringBuffer() + sb = StringBuffer() total = 0 while not inflater.finished(): if max_length: @@ -149,7 +150,7 @@ break total += l - sb.append(lang.String(buf, 0, 0, l)) + sb.append(String(buf, 0, 0, l)) if max_length and total == max_length: break return sb.toString() This was sent by the SourceForge.net collaborative development platform, the world's largest 
Open Source development site. |
From: <otm...@us...> - 2007-05-28 11:33:49
|
Revision: 3244 http://svn.sourceforge.net/jython/?rev=3244&view=rev Author: otmarhumbel Date: 2007-05-28 04:33:44 -0700 (Mon, 28 May 2007) Log Message: ----------- use explicit imports (to allow standalone mode) Modified Paths: -------------- trunk/jython/Lib/isql.py trunk/jython/Lib/javaos.py trunk/jython/Lib/javapath.py trunk/jython/Lib/popen2.py trunk/jython/Lib/socket.py Modified: trunk/jython/Lib/isql.py =================================================================== --- trunk/jython/Lib/isql.py 2007-05-25 23:05:37 UTC (rev 3243) +++ trunk/jython/Lib/isql.py 2007-05-28 11:33:44 UTC (rev 3244) @@ -28,7 +28,7 @@ return prompt if os.name == 'java': def __tojava__(self, cls): - import java + import java.lang.String if cls == java.lang.String: return self.__str__() return False Modified: trunk/jython/Lib/javaos.py =================================================================== --- trunk/jython/Lib/javaos.py 2007-05-25 23:05:37 UTC (rev 3243) +++ trunk/jython/Lib/javaos.py 2007-05-28 11:33:44 UTC (rev 3244) @@ -26,8 +26,8 @@ "popen", "popen2", "popen3", "popen4", "getlogin" ] -import java from java.io import File +import java.lang.System import javapath as path from UserDict import UserDict Modified: trunk/jython/Lib/javapath.py =================================================================== --- trunk/jython/Lib/javapath.py 2007-05-25 23:05:37 UTC (rev 3243) +++ trunk/jython/Lib/javapath.py 2007-05-28 11:33:44 UTC (rev 3244) @@ -14,8 +14,8 @@ # sameopenfile -- Java doesn't have fstat nor file descriptors? # samestat -- How? 
-import java from java.io import File +import java.io.IOException from java.lang import System import os Modified: trunk/jython/Lib/popen2.py =================================================================== --- trunk/jython/Lib/popen2.py 2007-05-25 23:05:37 UTC (rev 3243) +++ trunk/jython/Lib/popen2.py 2007-05-28 11:33:44 UTC (rev 3244) @@ -22,8 +22,11 @@ import jarray from java.lang import System -from java.util import * -from java.io import * +from java.util import Vector +from java.io import BufferedOutputStream +from java.io import BufferedInputStream +from java.io import PipedOutputStream +from java.io import PipedInputStream from org.python.core import PyFile from javashell import shellexecute Modified: trunk/jython/Lib/socket.py =================================================================== --- trunk/jython/Lib/socket.py 2007-05-25 23:05:37 UTC (rev 3243) +++ trunk/jython/Lib/socket.py 2007-05-28 11:33:44 UTC (rev 3244) @@ -9,9 +9,15 @@ - 20050527: updated by Alan Kennedy to support socket timeouts. """ -import java.io -import java.net -import org.python.core +import java.io.InterruptedIOException +import java.net.DatagramSocket +import java.net.DatagramPacket +import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.ServerSocket +import java.net.Socket +import java.net.SocketTimeoutException +import org.python.core.PyFile import jarray import string This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <cg...@us...> - 2007-06-18 06:17:11
|
Revision: 3256 http://svn.sourceforge.net/jython/?rev=3256&view=rev Author: cgroves Date: 2007-06-17 23:17:10 -0700 (Sun, 17 Jun 2007) Log Message: ----------- Copy Alan's new socket and select into trunk from the sandbox Added Paths: ----------- trunk/jython/Lib/select.py trunk/jython/Lib/socket.py trunk/jython/Lib/test/test_select.py trunk/jython/Lib/test/test_select_new.py trunk/jython/Lib/test/test_socket.py Removed Paths: ------------- trunk/jython/Lib/socket.py trunk/jython/Lib/test/test_socket.py Copied: trunk/jython/Lib/select.py (from rev 3255, trunk/sandbox/kennedya/asynch_sockets/select.py) =================================================================== --- trunk/jython/Lib/select.py (rev 0) +++ trunk/jython/Lib/select.py 2007-06-18 06:17:10 UTC (rev 3256) @@ -0,0 +1,157 @@ +""" +AMAK: 20070515: New select implementation that uses java.nio +""" + +import java.nio.channels.SelectableChannel +import java.nio.channels.SelectionKey +import java.nio.channels.Selector +from java.nio.channels.SelectionKey import OP_ACCEPT, OP_CONNECT, OP_WRITE, OP_READ + +import socket + +class error(Exception): pass + +POLLIN = 1 +POLLOUT = 2 + +# The following event types are completely ignored on jython +# Java does not support them, AFAICT +# They are declared only to support code compatibility with cpython + +POLLPRI = 4 +POLLERR = 8 +POLLHUP = 16 +POLLNVAL = 32 + +class poll: + + def __init__(self): + self.selector = java.nio.channels.Selector.open() + self.chanmap = {} + self.unconnected_sockets = [] + + def _getselectable(self, socket_object): + for st in socket.SocketTypes: + if isinstance(socket_object, st): + try: + return socket_object.getchannel() + except: + return None + raise error("Object '%s' is not watchable" % socket_object, 10038) + + def _register_channel(self, socket_object, channel, mask): + jmask = 0 + if mask & POLLIN: + # Note that OP_READ is NOT a valid event on server socket channels. 
+ if channel.validOps() & OP_ACCEPT: + jmask = OP_ACCEPT + else: + jmask = OP_READ + if mask & POLLOUT: + jmask |= OP_WRITE + if channel.validOps() & OP_CONNECT: + jmask |= OP_CONNECT + selectionkey = channel.register(self.selector, jmask) + self.chanmap[channel] = (socket_object, selectionkey) + + def _check_unconnected_sockets(self): + temp_list = [] + for socket_object, mask in self.unconnected_sockets: + channel = self._getselectable(socket_object) + if channel is not None: + self._register_channel(socket_object, channel, mask) + else: + temp_list.append( (socket_object, mask) ) + self.unconnected_sockets = temp_list + + def register(self, socket_object, mask = POLLIN|POLLOUT|POLLPRI): + channel = self._getselectable(socket_object) + if channel is None: + # The socket is not yet connected, and thus has no channel + # Add it to a pending list, and return + self.unconnected_sockets.append( (socket_object, mask) ) + return + self._register_channel(socket_object, channel, mask) + + def unregister(self, socket_object): + channel = self._getselectable(socket_object) + self.chanmap[channel][1].cancel() + del self.chanmap[channel] + + def _dopoll(self, timeout=None): + if timeout is None or timeout < 0: + self.selector.select() + elif timeout == 0: + self.selector.selectNow() + else: + # No multiplication required: both cpython and java use millisecond timeouts + self.selector.select(timeout) + # The returned selectedKeys cannot be used from multiple threads! 
+ return self.selector.selectedKeys() + + def poll(self, timeout=None): + self._check_unconnected_sockets() + selectedkeys = self._dopoll(timeout) + results = [] + for k in selectedkeys.iterator(): + jmask = k.readyOps() + pymask = 0 + if jmask & OP_READ: pymask |= POLLIN + if jmask & OP_WRITE: pymask |= POLLOUT + if jmask & OP_ACCEPT: pymask |= POLLIN + if jmask & OP_CONNECT: pymask |= POLLOUT + # Now return the original userobject, and the return event mask + results.append( (self.chanmap[k.channel()][0], pymask) ) + return results + + def close(self): + for k in self.selector.keys(): + k.cancel() + self.selector.close() + +def _calcselecttimeoutvalue(value): + if value is None: + return None + try: + floatvalue = float(value) + except Exception, x: + raise TypeError("Select timeout value must be a number or None") + if value < 0: + raise error("Select timeout value cannot be negative", 10022) + if floatvalue < 0.000001: + return 0 + return int(floatvalue * 1000) # Convert to milliseconds + +def select ( read_fd_list, write_fd_list, outofband_fd_list, timeout=None): + timeout = _calcselecttimeoutvalue(timeout) + # First create a poll object to do the actual watching. + pobj = poll() + already_registered = {} + # Check the read list + try: + # AMAK: Need to remove all this list searching, change to a dictionary? 
+ for fd in read_fd_list: + mask = POLLIN + if fd in write_fd_list: + mask |= POLLOUT + pobj.register(fd, mask) + already_registered[fd] = 1 + # And now the write list + for fd in write_fd_list: + if not already_registered.has_key(fd): + pobj.register(fd, POLLOUT) + results = pobj.poll(timeout) + except AttributeError, ax: + if str(ax) == "__getitem__": + raise TypeError(ax) + raise ax + # Now start preparing the results + read_ready_list, write_ready_list, oob_ready_list = [], [], [] + for fd, mask in results: + if mask & POLLIN: + read_ready_list.append(fd) + if mask & POLLOUT: + write_ready_list.append(fd) + pobj.close() + return read_ready_list, write_ready_list, oob_ready_list + Deleted: trunk/jython/Lib/socket.py =================================================================== --- trunk/jython/Lib/socket.py 2007-06-18 05:03:36 UTC (rev 3255) +++ trunk/jython/Lib/socket.py 2007-06-18 06:17:10 UTC (rev 3256) @@ -1,434 +0,0 @@ -"""Preliminary socket module. - -XXX Restrictions: - -- Only INET sockets -- No asynchronous behavior -- No socket options -- Can't do a very good gethostbyaddr() right... -- 20050527: updated by Alan Kennedy to support socket timeouts. 
-""" - -import java.io.InterruptedIOException -import java.net.DatagramSocket -import java.net.DatagramPacket -import java.net.InetAddress -import java.net.InetSocketAddress -import java.net.ServerSocket -import java.net.Socket -import java.net.SocketTimeoutException -import org.python.core.PyFile -import jarray -import string - -__all__ = [ 'AF_INET', 'SO_REUSEADDR', 'SOCK_DGRAM', 'SOCK_RAW', - 'SOCK_RDM', 'SOCK_SEQPACKET', 'SOCK_STREAM', 'SOL_SOCKET', - 'SocketType', 'error', 'getfqdn', 'gethostbyaddr', - 'gethostbyname', 'gethostname', 'socket', 'getaddrinfo'] - -error = IOError -class timeout(error): pass - -AF_INET = 2 - -SOCK_DGRAM = 1 -SOCK_STREAM = 2 -SOCK_RAW = 3 # not supported -SOCK_RDM = 4 # not supported -SOCK_SEQPACKET = 5 # not supported -SOL_SOCKET = 0xFFFF -SO_REUSEADDR = 4 - -def _gethostbyaddr(name): - # This is as close as I can get; at least the types are correct... - addresses = java.net.InetAddress.getAllByName(gethostbyname(name)) - names = [] - addrs = [] - for addr in addresses: - names.append(addr.getHostName()) - addrs.append(addr.getHostAddress()) - return (names, addrs) - -def getfqdn(name=None): - """ - Return a fully qualified domain name for name. If name is omitted or empty - it is interpreted as the local host. To find the fully qualified name, - the hostname returned by gethostbyaddr() is checked, then aliases for the - host, if available. The first name which includes a period is selected. - In case no fully qualified domain name is available, the hostname is retur - New in version 2.0. 
- """ - if not name: - name = gethostname() - names, addrs = _gethostbyaddr(name) - for a in names: - if a.find(".") >= 0: - return a - return name - -def gethostname(): - return java.net.InetAddress.getLocalHost().getHostName() - -def gethostbyname(name): - return java.net.InetAddress.getByName(name).getHostAddress() - -def gethostbyaddr(name): - names, addrs = _gethostbyaddr(name) - return (names[0], names, addrs) - -def socket(family = AF_INET, type = SOCK_STREAM, flags=0): - assert family == AF_INET - assert type in (SOCK_DGRAM, SOCK_STREAM) - assert flags == 0 - if type == SOCK_STREAM: - return _tcpsocket() - else: - return _udpsocket() - -def getaddrinfo(host, port, family=0, socktype=SOCK_STREAM, proto=0, flags=0): - return ( (AF_INET, socktype, 0, "", (gethostbyname(host), port)), ) - -_defaulttimeout = None - -def getdefaulttimeout(): - return _defaulttimeout - -def _get_timeout_value(value): - if value is None: - return None - try: - floatval = float(value) - except ValueError: - raise TypeError('A float is required') - if floatval < 0: - raise ValueError('Timeout value out of range') - if floatval < 0.001: # 1 millisecond - # java interprets a zero timeout as an infinite timeout - # python interprets a zero timeout as equivalent to non-blocking - # we cannot represent python semantics for a zero timeout on - # java (if we want it to work on pre 1.4 JVMs) - # so we use the shortest timeout possible, 1.1 millisecond - return 0.0011 - return floatval - -def setdefaulttimeout(timeout): - try: - global _defaulttimeout - _defaulttimeout = _get_timeout_value(timeout) - finally: - _tcpsocket.timeout = _defaulttimeout - -class _tcpsocket: - - sock = None - istream = None - ostream = None - addr = None - server = 0 - file_count = 0 - reuse_addr = 0 - - def __init__(self): - self.timeout = _defaulttimeout - - def bind(self, addr, port=None): - if port is not None: - addr = (addr, port) - assert not self.sock - assert not self.addr - host, port = addr # format check 
- self.addr = addr - - def listen(self, backlog=50): - "This signifies a server socket" - assert not self.sock - self.server = 1 - if self.addr: - host, port = self.addr - else: - host, port = "", 0 - if host: - a = java.net.InetAddress.getByName(host) - self.sock = java.net.ServerSocket(port, backlog, a) - else: - self.sock = java.net.ServerSocket(port, backlog) - if hasattr(self.sock, "setReuseAddress"): - self.sock.setReuseAddress(self.reuse_addr) - - def accept(self): - "This signifies a server socket" - if not self.sock: - self.listen() - assert self.server - if self.timeout: - self.sock.setSoTimeout(int(self.timeout*1000)) - try: - sock = self.sock.accept() - except java.net.SocketTimeoutException, jnste: - raise timeout('timed out') - host = sock.getInetAddress().getHostName() - port = sock.getPort() - conn = _tcpsocket() - conn._setup(sock) - return conn, (host, port) - - def connect(self, addr, port=None): - "This signifies a client socket" - if port is not None: - addr = (addr, port) - assert not self.sock - host, port = addr - if host == "": - host = java.net.InetAddress.getLocalHost() - try: - cli_sock = java.net.Socket() - addr = java.net.InetSocketAddress(host, port) - if self.timeout: - cli_sock.connect(addr, int(self.timeout*1000)) - else: - cli_sock.connect(addr) - self._setup(cli_sock) - except java.net.SocketTimeoutException, jnste: - raise timeout('timed out') - - def _setup(self, sock): - self.sock = sock - if hasattr(self.sock, "setReuseAddress"): - self.sock.setReuseAddress(self.reuse_addr) - self.istream = sock.getInputStream() - self.ostream = sock.getOutputStream() - - def recv(self, n): - assert self.sock - data = jarray.zeros(n, 'b') - try: - m = self.istream.read(data) - except java.io.InterruptedIOException , jiiie: - raise timeout('timed out') - if m <= 0: - return "" - if m < n: - data = data[:m] - return data.tostring() - - def send(self, s): - assert self.sock - n = len(s) - self.ostream.write(s) - return n - - sendall = send - - 
def getsockname(self): - if not self.sock: - host, port = self.addr or ("", 0) - host = java.net.InetAddress.getByName(host).getHostAddress() - else: - if self.server: - host = self.sock.getInetAddress().getHostAddress() - else: - host = self.sock.getLocalAddress().getHostAddress() - port = self.sock.getLocalPort() - return (host, port) - - def getpeername(self): - assert self.sock - assert not self.server - host = self.sock.getInetAddress().getHostAddress() - port = self.sock.getPort() - return (host, port) - - def setsockopt(self, level, optname, value): - if optname == SO_REUSEADDR: - self.reuse_addr = value - - def getsockopt(self, level, optname): - if optname == SO_REUSEADDR: - return self.reuse_addr - - def makefile(self, mode="r", bufsize=-1): - file = None - if self.istream: - if self.ostream: - file = org.python.core.PyFile(self.istream, self.ostream, - "<socket>", mode) - else: - file = org.python.core.PyFile(self.istream, "<socket>", mode) - elif self.ostream: - file = org.python.core.PyFile(self.ostream, "<socket>", mode) - else: - raise IOError, "both istream and ostream have been shut down" - if file: - return _tcpsocket.FileWrapper(self, file) - - class FileWrapper: - def __init__(self, socket, file): - self.socket = socket - self.sock = socket.sock - self.istream = socket.istream - self.ostream = socket.ostream - - self.file = file - self.read = file.read - self.readline = file.readline - self.readlines = file.readlines - self.write = file.write - self.writelines = file.writelines - self.flush = file.flush - self.seek = file.seek - self.tell = file.tell - - self.socket.file_count += 1 - - def close(self): - if self.file.closed: - # Already closed - return - - self.socket.file_count -= 1 - self.file.close() - - if self.socket.file_count == 0 and self.socket.sock == 0: - # This is the last file Only close the socket and streams - # if there are no outstanding files left. 
- if self.sock: - self.sock.close() - if self.istream: - self.istream.close() - if self.ostream: - self.ostream.close() - - def shutdown(self, how): - assert how in (0, 1, 2) - assert self.sock - if how in (0, 2): - self.istream = None - if how in (1, 2): - self.ostream = None - - def close(self): - if not self.sock: - return - sock = self.sock - istream = self.istream - ostream = self.ostream - self.sock = 0 - self.istream = 0 - self.ostream = 0 - # Only close the socket and streams if there are no - # outstanding files left. - if self.file_count == 0: - if istream: - istream.close() - if ostream: - ostream.close() - if sock: - sock.close() - - def gettimeout(self): - return self.timeout - - def settimeout(self, timeout): - self.timeout = _get_timeout_value(timeout) - if self.timeout and self.sock: - self.sock.setSoTimeout(int(self.timeout*1000)) - -class _udpsocket: - - def __init__(self): - self.sock = None - self.addr = None - - def bind(self, addr, port=None): - if port is not None: - addr = (addr, port) - assert not self.sock - host, port = addr - if host == "": - self.sock = java.net.DatagramSocket(port) - else: - a = java.net.InetAddress.getByName(host) - self.sock = java.net.DatagramSocket(port, a) - - def connect(self, addr, port=None): - if port is not None: - addr = (addr, port) - host, port = addr # format check - assert not self.addr - if not self.sock: - self.sock = java.net.DatagramSocket() - self.addr = addr # convert host to InetAddress instance? 
- - def sendto(self, data, addr): - n = len(data) - if not self.sock: - self.sock = java.net.DatagramSocket() - host, port = addr - bytes = jarray.array(map(ord, data), 'b') - a = java.net.InetAddress.getByName(host) - packet = java.net.DatagramPacket(bytes, n, a, port) - self.sock.send(packet) - return n - - def send(self, data): - assert self.addr - return self.sendto(data, self.addr) - - def recvfrom(self, n): - assert self.sock - bytes = jarray.zeros(n, 'b') - packet = java.net.DatagramPacket(bytes, n) - self.sock.receive(packet) - host = packet.getAddress().getHostName() - port = packet.getPort() - m = packet.getLength() - if m < n: - bytes = bytes[:m] - return bytes.tostring(), (host, port) - - def recv(self, n): - assert self.sock - bytes = jarray.zeros(n, 'b') - packet = java.net.DatagramPacket(bytes, n) - self.sock.receive(packet) - m = packet.getLength() - if m < n: - bytes = bytes[:m] - return bytes.tostring() - - def getsockname(self): - assert self.sock - host = self.sock.getLocalAddress().getHostName() - port = self.sock.getLocalPort() - return (host, port) - - def getpeername(self): - assert self.sock - host = self.sock.getInetAddress().getHostName() - port = self.sock.getPort() - return (host, port) - - def __del__(self): - self.close() - - def close(self): - if not self.sock: - return - sock = self.sock - self.sock = 0 - sock.close() - -SocketType = _tcpsocket - -def test(): - s = socket(AF_INET, SOCK_STREAM) - s.connect(("", 80)) - s.send("GET / HTTP/1.0\r\n\r\n") - while 1: - data = s.recv(2000) - print data - if not data: - break - -if __name__ == '__main__': - test() Copied: trunk/jython/Lib/socket.py (from rev 3255, trunk/sandbox/kennedya/asynch_sockets/socket.py) =================================================================== --- trunk/jython/Lib/socket.py (rev 0) +++ trunk/jython/Lib/socket.py 2007-06-18 06:17:10 UTC (rev 3256) @@ -0,0 +1,807 @@ +""" +This is an updated socket module for use on JVMs > 1.4; it is derived from the +old 
jython socket module. +The primary extra it provides is non-blocking support. + +XXX Restrictions: + +- Only INET sockets +- No asynchronous behavior +- No socket options +- Can't do a very good gethostbyaddr() right... +AMAK: 20050527: added socket timeouts +AMAK: 20070515: Added non-blocking (asynchronous) support +AMAK: 20070515: Added client-side SSL support +""" + +_defaulttimeout = None + +import threading +import time +import types +import jarray +import string +import sys + +import java.io.BufferedInputStream +import java.io.BufferedOutputStream +import java.io.InterruptedIOException +import java.lang.Exception +import java.lang.String +import java.net.BindException +import java.net.ConnectException +import java.net.DatagramPacket +import java.net.InetAddress +import java.net.InetSocketAddress +import java.net.Socket +import java.net.SocketTimeoutException +import java.nio.ByteBuffer +import java.nio.channels.DatagramChannel +import java.nio.channels.IllegalBlockingModeException +import java.nio.channels.ServerSocketChannel +import java.nio.channels.SocketChannel +import javax.net.ssl.SSLSocketFactory +import org.python.core.PyFile + +# Some errno constants, until we establish a separate errno module. 
+ +ERRNO_EACCESS = 10035 +ERRNO_EWOULDBLOCK = 10035 +ERRNO_EINPROGRESS = 10036 +ERRNO_ECONNREFUSED = 10061 + +class error(Exception): pass +class herror(error): pass +class gaierror(error): pass +class timeout(error): pass + +ALL = None + +exception_map = { + +# (<javaexception>, <circumstance>) : lambda: <code that raises the python equivalent> + +(java.io.InterruptedIOException, ALL) : lambda exc: timeout('timed out'), +(java.net.BindException, ALL) : lambda exc: error(ERRNO_EACCESS, 'Permission denied'), +(java.net.ConnectException, ALL) : lambda exc: error( (ERRNO_ECONNREFUSED, 'Connection refused') ), +(java.net.SocketTimeoutException, ALL) : lambda exc: timeout('timed out'), + +} + +def would_block_error(exc=None): + return error( (ERRNO_EWOULDBLOCK, 'The socket operation could not complete without blocking') ) + +def map_exception(exc, circumstance=ALL): + try: +# print "Mapping exception: %s" % str(exc) + return exception_map[(exc.__class__, circumstance)](exc) + except KeyError: + return error('Unmapped java exception: %s' % exc.toString()) + +exception_map.update({ + (java.nio.channels.IllegalBlockingModeException, ALL) : would_block_error, + }) + +MODE_BLOCKING = 'block' +MODE_NONBLOCKING = 'nonblock' +MODE_TIMEOUT = 'timeout' + +_permitted_modes = (MODE_BLOCKING, MODE_NONBLOCKING, MODE_TIMEOUT) + +class _nio_impl: + + timeout = None + mode = MODE_BLOCKING + + def read(self, buf): + bytebuf = java.nio.ByteBuffer.wrap(buf) + count = self.jchannel.read(bytebuf) + return count + + def write(self, buf): + bytebuf = java.nio.ByteBuffer.wrap(buf) + count = self.jchannel.write(bytebuf) + return count + + def _setreuseaddress(self, flag): + self.jsocket.setReuseAddress(flag) + + def _getreuseaddress(self, flag): + return self.jsocket.getReuseAddress() + + def getpeername(self): + return (self.jsocket.getInetAddress().getHostName(), self.jsocket.getPort() ) + + def config(self, mode, timeout): + self.mode = mode + if self.mode == MODE_BLOCKING: + 
self.jchannel.configureBlocking(1) + if self.mode == MODE_NONBLOCKING: + self.jchannel.configureBlocking(0) + if self.mode == MODE_TIMEOUT: + # self.channel.configureBlocking(0) + self.jsocket.setSoTimeout(int(timeout*1000)) + + def close1(self): + self.jsocket.close() + + def close2(self): + self.jchannel.close() + + def close3(self): + if not self.jsocket.isClosed(): + self.jsocket.close() + + def close4(self): + if not self.jsocket.isClosed(): + if hasattr(self.jsocket, 'shutdownInput') and not self.jsocket.isInputShutdown(): + self.jsocket.shutdownInput() + if hasattr(self.jsocket, 'shutdownOutput') and not self.jsocket.isOutputShutdown(): + self.jsocket.shutdownOutput() + self.jsocket.close() + + close = close1 +# close = close2 +# close = close3 +# close = close4 + + def getchannel(self): + return self.jchannel + + fileno = getchannel + +class _client_socket_impl(_nio_impl): + + def __init__(self, socket=None): + if socket: + self.jchannel = socket.getChannel() + self.host = socket.getInetAddress().getHostName() + self.port = socket.getPort() + else: + self.jchannel = java.nio.channels.SocketChannel.open() + self.host = None + self.port = None + self.jsocket = self.jchannel.socket() + + def bind(self, host, port): + self.jsocket.bind(java.net.InetSocketAddress(host, port)) + + def connect(self, host, port): + self.host = host + self.port = port + self.jchannel.connect(java.net.InetSocketAddress(self.host, self.port)) + + def finish_connect(self): + return self.jchannel.finishConnect() + + def close(self): + _nio_impl.close(self) + +class _server_socket_impl(_nio_impl): + + def __init__(self, host, port, backlog, reuse_addr): + self.jchannel = java.nio.channels.ServerSocketChannel.open() + self.jsocket = self.jchannel.socket() + if host: + bindaddr = java.net.InetSocketAddress(host, port) + else: + bindaddr = java.net.InetSocketAddress(port) + self._setreuseaddress(reuse_addr) + self.jsocket.bind(bindaddr, backlog) + + def accept(self): + try: + if self.mode 
in (MODE_BLOCKING, MODE_NONBLOCKING): + new_cli_chan = self.jchannel.accept() + if new_cli_chan != None: + return _client_socket_impl(new_cli_chan.socket()) + else: + return None + else: + # In timeout mode now + new_cli_sock = self.jsocket.accept() + return _client_socket_impl(new_cli_sock) + except java.lang.Exception, jlx: + raise map_exception(jlx) + + def close(self): + _nio_impl.close(self) + +class _datagram_socket_impl(_nio_impl): + + def __init__(self, port=None, address=None, reuse_addr=0): + self.jchannel = java.nio.channels.DatagramChannel.open() + self.jsocket = self.jchannel.socket() + if port: + if address is not None: + local_address = java.net.InetSocketAddress(address, port) + else: + local_address = java.net.InetSocketAddress(port) + self.jsocket.bind(local_address) + self._setreuseaddress(reuse_addr) + + def connect(self, host, port): + self.jchannel.connect(java.net.InetSocketAddress(host, port)) + + def finish_connect(self): + return self.jchannel.finishConnect() + + def receive(self, packet): + self.jsocket.receive(packet) + + def send(self, packet): + self.jsocket.send(packet) + +__all__ = [ 'AF_INET', 'SO_REUSEADDR', 'SOCK_DGRAM', 'SOCK_RAW', + 'SOCK_RDM', 'SOCK_SEQPACKET', 'SOCK_STREAM', 'SOL_SOCKET', + 'SocketType', 'SocketTypes', 'error', 'herror', 'gaierror', 'timeout', + 'getfqdn', 'gethostbyaddr', 'gethostbyname', 'gethostname', + 'socket', 'getaddrinfo', 'getdefaulttimeout', 'setdefaulttimeout', + 'has_ipv6', 'htons', 'htonl', 'ntohs', 'ntohl', + ] + +AF_INET = 2 + +SOCK_DGRAM = 1 +SOCK_STREAM = 2 +SOCK_RAW = 3 # not supported +SOCK_RDM = 4 # not supported +SOCK_SEQPACKET = 5 # not supported +SOL_SOCKET = 0xFFFF +SO_REUSEADDR = 4 + +def _gethostbyaddr(name): + # This is as close as I can get; at least the types are correct... 
+ addresses = java.net.InetAddress.getAllByName(gethostbyname(name)) + names = [] + addrs = [] + for addr in addresses: + names.append(addr.getHostName()) + addrs.append(addr.getHostAddress()) + return (names, addrs) + +def getfqdn(name=None): + """ + Return a fully qualified domain name for name. If name is omitted or empty + it is interpreted as the local host. To find the fully qualified name, + the hostname returned by gethostbyaddr() is checked, then aliases for the + host, if available. The first name which includes a period is selected. + In case no fully qualified domain name is available, the hostname is returned. + New in version 2.0. + """ + if not name: + name = gethostname() + names, addrs = _gethostbyaddr(name) + for a in names: + if a.find(".") >= 0: + return a + return name + +def gethostname(): + return java.net.InetAddress.getLocalHost().getHostName() + +def gethostbyname(name): + return java.net.InetAddress.getByName(name).getHostAddress() + +def gethostbyaddr(name): + names, addrs = _gethostbyaddr(name) + return (names[0], names, addrs) + +def getservbyname(servicename, protocolname=None): + # http://bugs.sun.com/bugdatabase/view_bug.do?bug_id=4071389 + # How complex is the structure of /etc/services?
+ raise NotImplementedError("getservbyname not yet supported on jython.") + +def getservbyport(port, protocolname=None): + # Same situation as above + raise NotImplementedError("getservbyport not yet supported on jython.") + +def getprotobyname(protocolname=None): + # Same situation as above + raise NotImplementedError("getprotobyname not yet supported on jython.") + +def socket(family = AF_INET, type = SOCK_STREAM, flags=0): + assert family == AF_INET + assert type in (SOCK_DGRAM, SOCK_STREAM) + assert flags == 0 + if type == SOCK_STREAM: + return _tcpsocket() + else: + return _udpsocket() + +def getaddrinfo(host, port, family=0, socktype=SOCK_STREAM, proto=0, flags=0): + return ( (AF_INET, socktype, 0, "", (gethostbyname(host), port)), ) + +has_ipv6 = 1 + +def getnameinfo(sock_addr, flags): + raise NotImplementedError("getnameinfo not yet supported on jython.") + +def getdefaulttimeout(): + return _defaulttimeout + +def _calctimeoutvalue(value): + if value is None: + return None + try: + floatvalue = float(value) + except: + raise TypeError('Socket timeout value must be a number or None') + if floatvalue < 0: + raise ValueError("Socket timeout value cannot be negative") + if floatvalue < 0.000001: + return 0.0 + return floatvalue + +def setdefaulttimeout(timeout): + global _defaulttimeout + try: + _defaulttimeout = _calctimeoutvalue(timeout) + finally: + _nonblocking_api_mixin.timeout = _defaulttimeout + +def htons(x): return x +def htonl(x): return x +def ntohs(x): return x +def ntohl(x): return x + +class _nonblocking_api_mixin: + + timeout = _defaulttimeout + mode = MODE_BLOCKING + + def gettimeout(self): + return self.timeout + + def settimeout(self, timeout): + self.timeout = _calctimeoutvalue(timeout) + if self.timeout is None: + self.mode = MODE_BLOCKING + elif self.timeout < 0.000001: + self.mode = MODE_NONBLOCKING + else: + self.mode = MODE_TIMEOUT + self._config() + + def setblocking(self, flag): + if flag: + self.mode = MODE_BLOCKING + self.timeout = 
None + else: + self.mode = MODE_NONBLOCKING + self.timeout = 0.0 + self._config() + + def _config(self): + assert self.mode in _permitted_modes + if self.sock_impl: self.sock_impl.config(self.mode, self.timeout) + + def getchannel(self): + if not self.sock_impl: + return None + return self.sock_impl.getchannel() +# if hasattr(self.sock_impl, 'getchannel'): +# return self.sock_impl.getchannel() +# raise error('Operation not implemented on this JVM') + + fileno = getchannel + + def _get_jsocket(self): + return self.sock_impl.jsocket + +def _unpack_address_tuple(address_tuple): + error_message = "Address must be a tuple of (hostname, port)" + if type(address_tuple) is not type( () ) \ + or type(address_tuple[0]) is not type("") \ + or type(address_tuple[1]) is not type(0): + raise TypeError(error_message) + return address_tuple[0], address_tuple[1] + +class _tcpsocket(_nonblocking_api_mixin): + + sock_impl = None + istream = None + ostream = None + local_addr = None + server = 0 + file_count = 0 + #reuse_addr = 1 + reuse_addr = 0 + + def bind(self, addr): + assert not self.sock_impl + assert not self.local_addr + # Do the address format check + host, port = _unpack_address_tuple(addr) + self.local_addr = addr + + def listen(self, backlog=50): + "This signifies a server socket" + try: + assert not self.sock_impl + self.server = 1 + if self.local_addr: + host, port = self.local_addr + else: + host, port = "", 0 + self.sock_impl = _server_socket_impl(host, port, backlog, self.reuse_addr) + self._config() + except java.lang.Exception, jlx: + raise map_exception(jlx) + +# +# The following has information on a java.lang.NullPointerException problem I'm having +# +# http://developer.java.sun.com/developer/bugParade/bugs/4801882.html + + def accept(self): + "This signifies a server socket" + try: + if not self.sock_impl: + self.listen() + assert self.server + new_sock = self.sock_impl.accept() + if not new_sock: + raise would_block_error() + cliconn = _tcpsocket() + 
cliconn._setup(new_sock) + return cliconn, new_sock.getpeername() + except java.lang.Exception, jlx: + raise map_exception(jlx) + + def _get_host_port(self, addr): + host, port = _unpack_address_tuple(addr) + if host == "": + host = java.net.InetAddress.getLocalHost() + return host, port + + def _do_connect(self, addr): + try: + assert not self.sock_impl + host, port = self._get_host_port(addr) + self.sock_impl = _client_socket_impl() + if self.local_addr: # Has the socket been bound to a local address? + bind_host, bind_port = self.local_addr + self.sock_impl.bind(bind_host, bind_port) + self._config() # Configure timeouts, etc, now that the socket exists + self.sock_impl.connect(host, port) + self._setup(self.sock_impl) + except java.lang.Exception, jlx: + raise map_exception(jlx) + + def connect(self, addr): + "This signifies a client socket" + self._do_connect(addr) + self._setup(self.sock_impl) + + def connect_ex(self, addr): + "This signifies a client socket" + self._do_connect(addr) + if self.sock_impl.finish_connect(): + self._setup(self.sock_impl) + return 0 + return ERRNO_EINPROGRESS + + def _setup(self, sock): + self.sock_impl = sock + self.sock_impl._setreuseaddress(self.reuse_addr) + if self.mode != MODE_NONBLOCKING: + self.istream = self.sock_impl.jsocket.getInputStream() + self.ostream = self.sock_impl.jsocket.getOutputStream() + + def recv(self, n): + try: + if not self.sock_impl: raise error('Socket not open') + if self.sock_impl.jchannel.isConnectionPending(): + self.sock_impl.jchannel.finishConnect() + data = jarray.zeros(n, 'b') + m = self.sock_impl.read(data) + if m <= 0: + if self.mode == MODE_NONBLOCKING: + raise would_block_error() + return "" + if m < n: + data = data[:m] + return data.tostring() + except java.lang.Exception, jlx: + raise map_exception(jlx) + + def recvfrom(self, n): + return self.recv(n), None + + def send(self, s): + if not self.sock_impl: raise error('Socket not open') + if self.sock_impl.jchannel.isConnectionPending(): 
+ self.sock_impl.jchannel.finishConnect() + #n = len(s) + numwritten = self.sock_impl.write(s) + return numwritten + + sendall = send + + def getsockname(self): + if not self.sock_impl: + host, port = self.local_addr or ("", 0) + host = java.net.InetAddress.getByName(host).getHostAddress() + else: + if self.server: + host = self.sock_impl.jsocket.getInetAddress().getHostAddress() + else: + host = self.sock_impl.jsocket.getLocalAddress().getHostAddress() + port = self.sock_impl.jsocket.getLocalPort() + return (host, port) + + def getpeername(self): + assert self.sock_impl + assert not self.server + host = self.sock_impl.jsocket.getInetAddress().getHostAddress() + port = self.sock_impl.jsocket.getPort() + return (host, port) + + def setsockopt(self, level, optname, value): + if optname == SO_REUSEADDR: + self.reuse_addr = value + + def getsockopt(self, level, optname): + if optname == SO_REUSEADDR: + return self.reuse_addr + + def makefile(self, mode="r", bufsize=-1): + file = None + if self.istream: + if self.ostream: + file = org.python.core.PyFile(self.istream, self.ostream, + "<socket>", mode) + else: + file = org.python.core.PyFile(self.istream, "<socket>", mode) + elif self.ostream: + file = org.python.core.PyFile(self.ostream, "<socket>", mode) + else: + raise IOError, "both istream and ostream have been shut down" + if file: + return _tcpsocket.FileWrapper(self, file) + + class FileWrapper: + def __init__(self, socket, file): + self.socket = socket + self.sock = socket.sock_impl + self.istream = socket.istream + self.ostream = socket.ostream + + self.file = file + self.read = file.read + self.readline = file.readline + self.readlines = file.readlines + self.write = file.write + self.writelines = file.writelines + self.flush = file.flush + self.seek = file.seek + self.tell = file.tell + self.closed = file.closed + + self.socket.file_count += 1 + + def close(self): + if self.file.closed: + # Already closed + return + + self.socket.file_count -= 1 + 
self.file.close() + self.closed = self.file.closed + + if self.socket.file_count == 0 and self.socket.sock_impl == 0: + # This is the last file Only close the socket and streams + # if there are no outstanding files left. + if self.sock: + self.sock.close() + if self.istream: + self.istream.close() + if self.ostream: + self.ostream.close() + + def shutdown(self, how): + assert how in (0, 1, 2) + assert self.sock_impl + if how in (0, 2): + self.istream = None + if how in (1, 2): + self.ostream = None + + def close(self): + if not self.sock_impl: + return + sock = self.sock_impl + istream = self.istream + ostream = self.ostream + self.sock_impl = 0 + self.istream = 0 + self.ostream = 0 + # Only close the socket and streams if there are no + # outstanding files left. + if self.file_count == 0: + if istream: + istream.close() + if ostream: + ostream.close() + if sock: + sock.close() + +class _udpsocket(_nonblocking_api_mixin): + + def __init__(self): + self.sock_impl = None + self.addr = None + self.reuse_addr = 0 + + def bind(self, addr): + assert not self.sock_impl + host, port = _unpack_address_tuple(addr) + host_address = java.net.InetAddress.getByName(host) + self.sock_impl = _datagram_socket_impl(port, host_address, reuse_addr = self.reuse_addr) + self._config() + + def connect(self, addr): + host, port = _unpack_address_tuple(addr) + assert not self.addr + if not self.sock_impl: + self.sock_impl = _datagram_socket_impl() + self._config() + self.sock_impl.connect(host, port) + self.addr = addr # convert host to InetAddress instance? 
+ + def connect_ex(self, addr): + host, port = _unpack_address_tuple(addr) + assert not self.addr + self.addr = addr + if not self.sock_impl: + self.sock_impl = _datagram_socket_impl() + self._config() + self.sock_impl.connect(host, port) + if self.sock_impl.finish_connect(): + return 0 + return ERRNO_EINPROGRESS + + def sendto(self, data, p1, p2=None): + if not p2: + flags, addr = 0, p1 + else: + flags, addr = 0, p2 + n = len(data) + if not self.sock_impl: + self.sock_impl = _datagram_socket_impl() + host, port = addr + bytes = java.lang.String(data).getBytes('iso-8859-1') + a = java.net.InetAddress.getByName(host) + packet = java.net.DatagramPacket(bytes, n, a, port) + self.sock_impl.send(packet) + return n + + def send(self, data): + assert self.addr + return self.sendto(data, self.addr) + + def recvfrom(self, n): + try: + assert self.sock_impl + bytes = jarray.zeros(n, 'b') + packet = java.net.DatagramPacket(bytes, n) + self.sock_impl.receive(packet) + host = None + if packet.getAddress(): + host = packet.getAddress().getHostName() + port = packet.getPort() + m = packet.getLength() + if m < n: + bytes = bytes[:m] + return bytes.tostring(), (host, port) + except java.lang.Exception, jlx: + raise map_exception(jlx) + + def recv(self, n): + try: + assert self.sock_impl + bytes = jarray.zeros(n, 'b') + packet = java.net.DatagramPacket(bytes, n) + self.sock_impl.receive(packet) + m = packet.getLength() + if m < n: + bytes = bytes[:m] + return bytes.tostring() + except java.lang.Exception, jlx: + raise map_exception(jlx) + + def getsockname(self): + assert self.sock_impl + host = self.sock_impl.jsocket.getLocalAddress().getHostName() + port = self.sock_impl.jsocket.getLocalPort() + return (host, port) + + def getpeername(self): + assert self.sock + host = self.sock_impl.jsocket.getInetAddress().getHostName() + port = self.sock_impl.jsocket.getPort() + return (host, port) + + def __del__(self): + self.close() + + def close(self): + if not self.sock_impl: + return + 
sock = self.sock_impl + self.sock_impl = None + sock.close() + + def setsockopt(self, level, optname, value): + if optname == SO_REUSEADDR: + self.reuse_addr = value +# self.sock._setreuseaddress(value) + + def getsockopt(self, level, optname): + if optname == SO_REUSEADDR: + return self.sock_impl._getreuseaddress() + else: + return None + +SocketType = _tcpsocket +SocketTypes = [_tcpsocket, _udpsocket] + +# Define the SSL support + +class ssl: + + def __init__(self, plain_sock, keyfile=None, certfile=None): + self.ssl_sock = self.make_ssl_socket(plain_sock) + + def make_ssl_socket(self, plain_socket, auto_close=0): + java_net_socket = plain_socket._get_jsocket() + assert isinstance(java_net_socket, java.net.Socket) + host = java_net_socket.getInetAddress().getHostName() + port = java_net_socket.getPort() + factory = javax.net.ssl.SSLSocketFactory.getDefault(); + ssl_socket = factory.createSocket(java_net_socket, host, port, auto_close) + ssl_socket.setEnabledCipherSuites(ssl_socket.getSupportedCipherSuites()) + ssl_socket.startHandshake() + return ssl_socket + + def read(self, n=4096): + # Probably needs some work on efficency + in_buf = java.io.BufferedInputStream(self.ssl_sock.getInputStream()) + data = jarray.zeros(n, 'b') + m = in_buf.read(data, 0, n) + if m <= 0: + return "" + if m < n: + data = data[:m] + return data.tostring() + + def write(self, s): + # Probably needs some work on efficency + out = java.io.BufferedOutputStream(self.ssl_sock.getOutputStream()) + out.write(s) + out.flush() + + def _get_server_cert(self): + return self.ssl_sock.getSession().getPeerCertificates()[0] + + def server(self): + cert = self._get_server_cert() + return cert.getSubjectDN().toString() + + def issuer(self): + cert = self._get_server_cert() + return cert.getIssuerDN().toString() + +def test(): + s = socket(AF_INET, SOCK_STREAM) + s.connect(("", 80)) + s.send("GET / HTTP/1.0\r\n\r\n") + while 1: + data = s.recv(2000) + print data + if not data: + break + +if __name__ == 
'__main__': + test() Copied: trunk/jython/Lib/test/test_select.py (from rev 3255, trunk/sandbox/kennedya/asynch_sockets/test/test_select.py) =================================================================== --- trunk/jython/Lib/test/test_select.py (rev 0) +++ trunk/jython/Lib/test/test_select.py 2007-06-18 06:17:10 UTC (rev 3256) @@ -0,0 +1,214 @@ +""" +AMAK: 20050515: This module is the test_select.py from cpython 2.4, ported to jython + unittest +""" + +try: + object +except NameError: + class object: pass + +import socket, select + +import os +import sys +import unittest + +class SelectWrapper: + + def __init__(self): + self.read_fds = [] + self.write_fds = [] + self.oob_fds = [] + self.timeout = None + + def add_read_fd(self, fd): + self.read_fds.append(fd) + + def add_write_fd(self, fd): + self.write_fds.append(fd) + + def add_oob_fd(self, fd): + self.oob_fds.append(fd) + + def set_timeout(self, timeout): + self.timeout = timeout + +class PollWrapper: + + def __init__(self): + self.timeout = None + self.poll_object = select.poll() + + def add_read_fd(self, fd): + self.poll_object.register(fd, select.POLL_IN) + + def add_write_fd(self, fd): + self.poll_object.register(fd, select.POLL_OUT) + + def add_oob_fd(self, fd): + self.poll_object.register(fd, select.POLL_PRI) + +class TestSelectInvalidParameters(unittest.TestCase): + + def testBadSelectSetTypes(self): + # Test some known error conditions + for bad_select_set in [None, 1,]: + for pos in range(2): # OOB not supported on Java + args = [[], [], []] + args[pos] = bad_select_set + try: + timeout = 0 # Can't wait forever + rfd, wfd, xfd = select.select(args[0], args[1], args[2], timeout) + except TypeError: + pass + else: + self.fail("Selecting on '%s' should have raised TypeError" % str(bad_select_set)) + + def testBadSelectableTypes(self): + class Nope: pass + + class Almost1: + def fileno(self): + return 'fileno' + + class Almost2: + def fileno(self): + return 'fileno' + + # Test some known error 
conditions + for bad_selectable in [None, 1, object(), Nope(), Almost1(), Almost2()]: + try: + timeout = 0 # Can't wait forever + rfd, wfd, xfd = select.select([bad_selectable], [], [], timeout) + except (TypeError, select.error), x: + pass + else: + self.fail("Selecting on '%s' should have raised TypeError or select.error" % str(bad_selectable)) + + def testInvalidTimeoutTypes(self): + for invalid_timeout in ['not a number']: + try: + rfd, wfd, xfd = select.select([], [], [], invalid_timeout) + except TypeError: + pass + else: + self.fail("Invalid timeout value '%s' should have raised TypeError" % invalid_timeout) + + def testInvalidTimeoutValues(self): + for invalid_timeout in [-1]: + try: + rfd, wfd, xfd = select.select([], [], [], invalid_timeout) + except (ValueError, select.error): + pass + else: + self.fail("Invalid timeout value '%s' should have raised ValueError or select.error" % invalid_timeout) + +class TestSelectClientSocket(unittest.TestCase): + + def testUnconnectedSocket(self): + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for x in range(5)] + for pos in range(2): # OOB not supported on Java + args = [[], [], []] + args[pos] = sockets + timeout = 0 # Can't wait forever + rfd, wfd, xfd = select.select(args[0], args[1], args[2], timeout) + for s in sockets: + self.failIf(s in rfd) + self.failIf(s in wfd) + +def check_server_running_on_localhost_port(port_number): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + s.connect( ('localhost', port_number) ) + s.close() + except: + return 0 + return 1 + +class TestPollClientSocket(unittest.TestCase): + + def testEventConstants(self): + for event_name in ['IN', 'OUT', 'PRI', 'ERR', 'HUP', 'NVAL', ]: + self.failUnless(hasattr(select, 'POLL%s' % event_name)) + + def testSocketRegisteredBeforeConnected(self): + # You MUST be running a server on port 80 for this one to work + if not check_server_running_on_localhost_port(80): + print "Unable to run 
testSocketRegisteredBeforeConnected: no server on port 80" + return + sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for x in range(5)] + timeout = 1 # Can't wait forever + poll_object = select.poll() + for s in sockets: + # Register the sockets before they are connected + poll_object.register(s, select.POLLOUT) + result_list = poll_object.poll(timeout) + result_sockets = [r[0] for r in result_list] + for s in sockets: + self.failIf(s in result_sockets) + # Now connect the sockets, but DO NOT register them again + for s in sockets: + s.setblocking(0) + s.connect( ('localhost', 80) ) + # Now poll again, to see if the poll object has recognised that the sockets are now connected + result_list = poll_object.poll(timeout) + result_sockets = [r[0] for r in result_list] + for s in sockets: + self.failUnless(s in result_sockets) + + def testUnregisterRaisesKeyError(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + poll_object = select.poll() + try: + poll_object.unregister(s) + except KeyError: + pass + else: + self.fail("Unregistering socket that is not registered should have raised KeyError") + +class TestPipes(unittest.TestCase): + + verbose = 1 + + def test(self): + import sys + from test.test_support import verbose + if sys.platform[:3] in ('win', 'mac', 'os2', 'riscos'): + if verbose: + print "Can't test select easily on", sys.platform + return + cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done' + p = os.popen(cmd, 'r') + for tout in (0, 1, 2, 4, 8, 16) + (None,)*10: + if verbose: + print 'timeout =', tout + rfd, wfd, xfd = select.select([p], [], [], tout) + if (rfd, wfd, xfd) == ([], [], []): + continue + if (rfd, wfd, xfd) == ([p], [], []): + line = p.readline() + if verbose: + print repr(line) + if not line: + if verbose: + print 'EOF' + break + continue + self.fail('Unexpected return values from select(): %s' % str(rfd, wfd, xfd)) + p.close() + +def test_main(): + tests = [ + TestSelectInvalidParameters, + 
TestSelectClientSocket, + TestPollClientSocket, + ] + if sys.platform[:4] != 'java': + tests.append(TestPipes) + suites = [unittest.makeSuite(klass, 'test') for klass in tests] + main_suite = unittest.TestSuite(suites) + runner = unittest.TextTestRunner(verbosity=100) + runner.run(main_suite) + +if __name__ == "__main__": + test_main() Copied: trunk/jython/Lib/test/test_select_new.py (from rev 3255, trunk/sandbox/kennedya/asynch_sockets/test/test_select_new.py) =================================================================== --- trunk/jython/Lib/test/test_select_new.py (rev 0) +++ trunk/jython/Lib/test/test_select_new.py 2007-06-18 06:17:10 UTC (rev 3256) @@ -0,0 +1,276 @@ +""" +AMAK: 20050515: This module is a brand new test_select module, which gives much wider coverage. +""" + +import sys +import time +import unittest + +import socket +import select + +NOT_READY, READY = 0, 1 + +SERVER_ADDRESS = ("localhost", 54321) + +DATA_CHUNK_SIZE = 1000 ; DATA_CHUNK = "." * DATA_CHUNK_SIZE + +# +# The timing of these tests depends on how the underlying OS socket library
These values may need tweaking for different platforms +# +# The fundamental problem is that there is no reliable way to fill a socket with bytes +# + +if sys.platform[:4] == 'java': + SELECT_TIMEOUT = 0 +else: + # zero select timeout fails these tests on cpython (on windows 2003 anyway) + SELECT_TIMEOUT = 0.001 + +READ_TIMEOUT = 5 + +class AsynchronousServer: + + def __init__(self): + self.server_socket = None + + def create_socket(self): + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_socket.setblocking(0) + self.server_socket.bind(SERVER_ADDRESS) + self.server_socket.listen(5) + try: + self.server_socket.accept() + except socket.error: + pass + + def verify_acceptable_status(self, expected_acceptability): + actual_acceptability = NOT_READY + rfds, wfds, xfds = select.select([self.server_socket], [], [], SELECT_TIMEOUT) + if self.server_socket in rfds: + actual_acceptability = READY + assert actual_acceptability == expected_acceptability, "Server socket should %sbe acceptable" % {NOT_READY:'not ',READY:''}[expected_acceptability] + + def accept_connection(self): + rfds, wfds, xfds = select.select([self.server_socket], [], [], SELECT_TIMEOUT) + assert self.server_socket in rfds, "Server socket had no pending connections" + new_socket, address = self.server_socket.accept() + return AsynchronousHandler(new_socket) + + def close(self): + self.server_socket.close() + +class PeerImpl: + + def fill_outchannel(self): + """ + This implementation is sub-optimal. + It is reliant on how the OS handles the socket buffers. 
+ """ + total_bytes = 0 + while 1: + try: + rfds, wfds, xfds = select.select([], [self.socket], [], SELECT_TIMEOUT) + if self.socket in wfds: + bytes_sent = self.socket.send(DATA_CHUNK) + total_bytes += bytes_sent + else: + return total_bytes + except socket.error, se: + if se.value == 10035: + continue + raise se + + def read_inchannel(self, expected): + buf_size = expected ; results = "" ; start = time.time() + while 1: + if (expected - len(results)) < buf_size: + buf_size = expected - len(results) + rfds, wfds, xfds = select.select([self.socket], [], [], SELECT_TIMEOUT) + if self.socket in rfds: + recvd_bytes = self.socket.recv(buf_size) + if len(recvd_bytes): + results += recvd_bytes + if len(results) == expected: + return results + else: + stop = time.time() + if (stop - start) > READ_TIMEOUT: + raise Exception("Exceeded alloted time (%1.3lf > %1.3lf) to read %d bytes: got %d" % ((stop-start), READ_TIMEOUT, expected, len(results))) + + def verify_status(self, expected_readability, expected_writability): + actual_readability, actual_writability = NOT_READY, NOT_READY + rfds, wfds, xfds = select.select([self.socket], [self.socket], [], SELECT_TIMEOUT) + if self.socket in rfds: + actual_readability = READY + if self.socket in wfds: + actual_writability = READY + assert actual_readability == expected_readability, "Socket should %sbe ready for reading: %s" % ({NOT_READY:'not ',READY:''}[expected_readability], rfds) + assert actual_writability == expected_writability, "Socket should %sbe ready for writing: %s" % ({NOT_READY:'not ',READY:''}[expected_writability], wfds) + + def fileno(self): + return self.socket.fileno() + + def close(self): + self.socket.close() + +class AsynchronousHandler(PeerImpl): + + def __init__(self, new_socket): + self.socket = new_socket + self.socket.setblocking(0) + +class AsynchronousClient(PeerImpl): + + def __init__(self): + self.socket = None + self.connected = 0 + + def create_socket(self): + self.socket = 
socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setblocking(0) + + def start_connect(self): + result = self.socket.connect_ex(SERVER_ADDRESS) + if result == 0: + self.connected = 1 + + def finish_connect(self): + if self.connected: + return + rfds, wfds, xfds = select.select([], [self.socket], [], SELECT_TIMEOUT) + assert self.socket in wfds, "Client socket incomplete connect" + +def log(message): + print message + +class TestSelect(unittest.TestCase): + + def test000_CreateSockets(self): + # Create the server + TestSelect.server_socket = AsynchronousServer() + TestSelect.server_socket.create_socket() + + # Create the client + TestSelect.client_socket = AsynchronousClient() + TestSelect.client_socket.create_socket() + + def test100_ServerSocketNoPendingConnections(self): + # Check the server is not marked "acceptable" + TestSelect.server_socket.verify_acceptable_status(NOT_READY) + + def test110_ServerSocketPendingConnections(self): + # Start the client connection process + TestSelect.client_sock... [truncated message content] |
From: <am...@us...> - 2007-06-24 18:17:49
|
Revision: 3270 http://svn.sourceforge.net/jython/?rev=3270&view=rev Author: amak Date: 2007-06-24 11:17:44 -0700 (Sun, 24 Jun 2007) Log Message: ----------- 1. Fixed a bug in the interpretation of select.poll().poll() timeouts. Passing a float (as the number of milliseconds) was causing incorrect behaviour. 2. Changed a unit test for registration of sockets so that it creates and manages its own TCP server, rather than expecting a server to exist in the test environment. Modified Paths: -------------- trunk/jython/Lib/select.py trunk/jython/Lib/test/test_select.py Modified: trunk/jython/Lib/select.py =================================================================== --- trunk/jython/Lib/select.py 2007-06-24 17:57:58 UTC (rev 3269) +++ trunk/jython/Lib/select.py 2007-06-24 18:17:44 UTC (rev 3270) @@ -9,6 +9,14 @@ import socket +try: + import errno + ERRNO_EINVAL = errno.EINVAL + ERRNO_ENOTSOCK = errno.ENOTSOCK +except ImportError: + ERRNO_EINVAL = 22 + ERRNO_ENOTSOCK = 88 + class error(Exception): pass POLLIN = 1 @@ -37,7 +45,7 @@ return socket_object.getchannel() except: return None - raise error("Object '%s' is not watchable" % socket_object, 10038) + raise error("Object '%s' is not watchable" % socket_object, ERRNO_ENOTSOCK) def _register_channel(self, socket_object, channel, mask): jmask = 0 @@ -78,14 +86,19 @@ self.chanmap[channel][1].cancel() del self.chanmap[channel] - def _dopoll(self, timeout=None): + def _dopoll(self, timeout): if timeout is None or timeout < 0: self.selector.select() - elif timeout == 0: - self.selector.selectNow() else: - # No multiplication required: both cpython and java use millisecond timeouts - self.selector.select(timeout) + try: + timeout = int(timeout) + if timeout == 0: + self.selector.selectNow() + else: + # No multiplication required: both cpython and java use millisecond timeouts + self.selector.select(timeout) + except ValueError, vx: + raise error("poll timeout must be a number of milliseconds or None", ERRNO_EINVAL) # The 
returned selectedKeys cannot be used from multiple threads! return self.selector.selectedKeys() @@ -117,7 +130,7 @@ except Exception, x: raise TypeError("Select timeout value must be a number or None") if value < 0: - raise error("Select timeout value cannot be negative", 10022) + raise error("Select timeout value cannot be negative", ERRNO_EINVAL) if floatvalue < 0.000001: return 0 return int(floatvalue * 1000) # Convert to milliseconds Modified: trunk/jython/Lib/test/test_select.py =================================================================== --- trunk/jython/Lib/test/test_select.py 2007-06-24 17:57:58 UTC (rev 3269) +++ trunk/jython/Lib/test/test_select.py 2007-06-24 18:17:44 UTC (rev 3270) @@ -116,46 +116,12 @@ self.failIf(s in rfd) self.failIf(s in wfd) -def check_server_running_on_localhost_port(port_number): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - s.connect( ('localhost', port_number) ) - s.close() - except: - return 0 - return 1 - class TestPollClientSocket(unittest.TestCase): def testEventConstants(self): for event_name in ['IN', 'OUT', 'PRI', 'ERR', 'HUP', 'NVAL', ]: self.failUnless(hasattr(select, 'POLL%s' % event_name)) - def testSocketRegisteredBeforeConnected(self): - # You MUST be running a server on port 80 for this one to work - if not check_server_running_on_localhost_port(80): - print "Unable to run testSocketRegisteredBeforeConnected: no server on port 80" - return - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM) for x in range(5)] - timeout = 1 # Can't wait forever - poll_object = select.poll() - for s in sockets: - # Register the sockets before they are connected - poll_object.register(s, select.POLLOUT) - result_list = poll_object.poll(timeout) - result_sockets = [r[0] for r in result_list] - for s in sockets: - self.failIf(s in result_sockets) - # Now connect the sockets, but DO NOT register them again - for s in sockets: - s.setblocking(0) - s.connect( ('localhost', 80) ) - # Now poll again, to 
see if the poll object has recognised that the sockets are now connected - result_list = poll_object.poll(timeout) - result_sockets = [r[0] for r in result_list] - for s in sockets: - self.failUnless(s in result_sockets) - def testUnregisterRaisesKeyError(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) poll_object = select.poll() @@ -166,6 +132,33 @@ else: self.fail("Unregistering socket that is not registered should have raised KeyError") +# +# using the test_socket thread based server/client management, for convenience. +# + +import test_socket + +class ThreadedPollClientSocket(test_socket.ThreadedTCPSocketTest): + + def testSocketRegisteredBeforeConnected(self): + self.cli_conn = self.serv.accept() + + def _testSocketRegisteredBeforeConnected(self): + timeout = 1000 # milliseconds + poll_object = select.poll() + # Register the socket before it is connected + poll_object.register(self.cli, select.POLLOUT) + result_list = poll_object.poll(timeout) + result_sockets = [r[0] for r in result_list] + self.failIf(self.cli in result_sockets, "Unconnected client socket should not have been selectable") + # Now connect the socket, but DO NOT register it again + self.cli.setblocking(0) + self.cli.connect( (test_socket.HOST, test_socket.PORT) ) + # Now poll again, to check that the poll object has recognised that the socket is now connected + result_list = poll_object.poll(timeout) + result_sockets = [r[0] for r in result_list] + self.failUnless(self.cli in result_sockets, "Connected client socket should have been selectable") + class TestPipes(unittest.TestCase): verbose = 1 @@ -202,6 +195,7 @@ TestSelectInvalidParameters, TestSelectClientSocket, TestPollClientSocket, + ThreadedPollClientSocket, ] if sys.platform[:4] != 'java': tests.append(TestPipes) This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <am...@us...> - 2007-07-03 19:53:06
|
Revision: 3280 http://svn.sourceforge.net/jython/?rev=3280&view=rev Author: amak Date: 2007-07-03 12:53:01 -0700 (Tue, 03 Jul 2007) Log Message: ----------- 1. Found and fixed a bug whereby timeouts on client connects were not being honoured. Added a unit test to check that they are. 2. Added several jython-specific unit tests to ensure that java->python exception mapping is correct. Modified Paths: -------------- trunk/jython/Lib/select.py trunk/jython/Lib/socket.py trunk/jython/Lib/test/test_socket.py Modified: trunk/jython/Lib/select.py =================================================================== --- trunk/jython/Lib/select.py 2007-07-01 18:06:51 UTC (rev 3279) +++ trunk/jython/Lib/select.py 2007-07-03 19:53:01 UTC (rev 3280) @@ -34,7 +34,7 @@ try: return _exception_map[(exc.__class__, circumstance)](exc) except KeyError: - return error('Unmapped java exception: %s' % exc.toString()) + return error(-1, 'Unmapped java exception: %s' % exc.toString()) POLLIN = 1 POLLOUT = 2 Modified: trunk/jython/Lib/socket.py =================================================================== --- trunk/jython/Lib/socket.py 2007-07-01 18:06:51 UTC (rev 3279) +++ trunk/jython/Lib/socket.py 2007-07-03 19:53:01 UTC (rev 3280) @@ -41,6 +41,7 @@ import java.nio.channels.IllegalBlockingModeException import java.nio.channels.ServerSocketChannel import java.nio.channels.SocketChannel +import java.nio.channels.UnresolvedAddressException import javax.net.ssl.SSLSocketFactory import org.python.core.PyFile @@ -72,9 +73,11 @@ (java.io.InterruptedIOException, ALL) : lambda exc: timeout('timed out'), (java.net.BindException, ALL) : lambda exc: error(ERRNO_EACCES, 'Permission denied'), -(java.net.ConnectException, ALL) : lambda exc: error( (ERRNO_ECONNREFUSED, 'Connection refused') ), +(java.net.ConnectException, ALL) : lambda exc: error(ERRNO_ECONNREFUSED, 'Connection refused'), (java.net.SocketTimeoutException, ALL) : lambda exc: timeout('timed out'), (java.net.UnknownHostException, ALL) 
: lambda exc: gaierror(ERRNO_EGETADDRINFOFAILED, 'getaddrinfo failed'), +(java.nio.channels.UnresolvedAddressException, ALL) : lambda exc: gaierror(ERRNO_EGETADDRINFOFAILED, 'getaddrinfo failed'), + } def would_block_error(exc=None): @@ -84,7 +87,7 @@ try: return _exception_map[(exc.__class__, circumstance)](exc) except KeyError: - return error('Unmapped java exception: %s' % exc.toString()) + return error(-1, 'Unmapped java exception: %s' % exc.toString()) MODE_BLOCKING = 'block' MODE_NONBLOCKING = 'nonblock' @@ -123,8 +126,8 @@ if self.mode == MODE_NONBLOCKING: self.jchannel.configureBlocking(0) if self.mode == MODE_TIMEOUT: - # self.channel.configureBlocking(0) - self.jsocket.setSoTimeout(int(timeout*1000)) + self._timeout_millis = int(timeout*1000) + self.jsocket.setSoTimeout(self._timeout_millis) def close1(self): self.jsocket.close() @@ -173,7 +176,10 @@ def connect(self, host, port): self.host = host self.port = port - self.jchannel.connect(java.net.InetSocketAddress(self.host, self.port)) + if self.mode == MODE_TIMEOUT: + self.jsocket.connect(java.net.InetSocketAddress(self.host, self.port), self._timeout_millis) + else: + self.jchannel.connect(java.net.InetSocketAddress(self.host, self.port)) def finish_connect(self): return self.jchannel.finishConnect() Modified: trunk/jython/Lib/test/test_socket.py =================================================================== --- trunk/jython/Lib/test/test_socket.py 2007-07-01 18:06:51 UTC (rev 3279) +++ trunk/jython/Lib/test/test_socket.py 2007-07-03 19:53:01 UTC (rev 3280) @@ -711,7 +711,7 @@ def _testConnectWithLocalBind(self): # Testing blocking connect with local bind - self.cli.settimeout(10) + self.cli.settimeout(1) self.cli.bind( (HOST, PORT-1) ) self.cli.connect((HOST, PORT)) bound_host, bound_port = self.cli.getsockname() @@ -886,6 +886,22 @@ if not ok: self.fail("accept() returned success when we did not expect it") +class TCPClientTimeoutTest(ThreadedTCPSocketTest): + + def testTCPClientTimeout(self): + 
pass # i.e. do not accept + + def _testTCPClientTimeout(self): + try: + self.cli.settimeout(0.1) + self.cli.connect( (HOST, PORT) ) + except socket.timeout, st: + pass + except Exception, x: + self.fail("Client socket timeout should have raised socket.timeout, not %s" % str(x)) + else: + self.fail("Client socket timeout should have raised socket.timeout") + # # AMAK: 20070307 # Corrected the superclass of UDPTimeoutTest @@ -922,6 +938,8 @@ self.assert_(issubclass(socket.gaierror, socket.error)) self.assert_(issubclass(socket.timeout, socket.error)) +class TestJythonExceptions(unittest.TestCase): + def testHostNotFound(self): try: socket.gethostbyname("doesnotexist") @@ -930,6 +948,49 @@ except Exception, x: self.fail("Get host name for non-existent host raised wrong exception: %s" % x) + def testConnectionRefused(self): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # This port should not be open at this time + s.connect( (HOST, PORT) ) + except socket.error, se: + self.failUnlessEqual(se[0], errno.ECONNREFUSED) + except Exception, x: + self.fail("Connection to non-existent host/port raised wrong exception: %s" % x) + else: + self.fail("Socket (%s,%s) should not have been listening at this time" % (HOST, PORT)) + + def testBindException(self): + # First bind to the target port + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind( (HOST, PORT) ) + s.listen() + try: + try: + # And then try to bind again + t = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + t.bind( (HOST, PORT) ) + t.listen() + except socket.error, se: + self.failUnlessEqual(se[0], errno.EACCES) + except Exception, x: + self.fail("Binding to already bound host/port raised wrong exception: %s" % x) + else: + self.fail("Binding to already bound host/port should have raised exception") + finally: + s.close() + + def testUnresolvedAddress(self): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect( ('non.existent.server', PORT) ) + except 
socket.gaierror, gaix: + self.failUnlessEqual(gaix[0], errno.EGETADDRINFOFAILED) + except Exception, x: + self.fail("Get host name for non-existent host raised wrong exception: %s" % x) + else: + self.fail("Get host name for non-existent host should have raised exception") + class TestAddressParameters: def testBindNonTupleEndpointRaisesTypeError(self): @@ -971,6 +1032,7 @@ GeneralModuleTests, BasicTCPTest, TCPTimeoutTest, + TCPClientTimeoutTest, TestExceptions, TestTCPAddressParameters, TestUDPAddressParameters, @@ -985,6 +1047,8 @@ ] if hasattr(socket, "socketpair"): tests.append(BasicSocketPairTest) + if sys.platform[:4] == 'java': + tests.append(TestJythonExceptions) suites = [unittest.makeSuite(klass, 'test') for klass in tests] main_suite = unittest.TestSuite(suites) runner = unittest.TextTestRunner(verbosity=100) This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <am...@us...> - 2007-07-16 19:13:27
|
Revision: 3319 http://svn.sourceforge.net/jython/?rev=3319&view=rev Author: amak Date: 2007-07-16 12:13:26 -0700 (Mon, 16 Jul 2007) Log Message: ----------- . Fixed a bug with FileWrappers, whereby closing the FileWrapper incorrectly caused closing of the underlying socket. . Wrapped several more methods in try .. except for exception mapping . Removed unnecessary lambdas from exception mapping table . Completed the exception mapping table with all java.net and java.nio exceptions . Removed errno module simulation: was unnecessary (jython 2.1 users can easily create their own errno module) . Added constants SHUT_RD, SHUT_WR and SHUT_RDWR, as used by shutdown() . Added proper socket shutdown support, as opposed to the old method of closing the [In|Out]putStreams . Removed some comments and commented-out code . Did some basic refactoring of UDP connect methods, to reduce code duplication . Added two more unit tests for closing sockets and FileWrappers . Added more unit tests for exception mapping Modified Paths: -------------- trunk/jython/Lib/socket.py trunk/jython/Lib/test/test_socket.py Modified: trunk/jython/Lib/socket.py =================================================================== --- trunk/jython/Lib/socket.py 2007-07-16 02:44:37 UTC (rev 3318) +++ trunk/jython/Lib/socket.py 2007-07-16 19:13:26 UTC (rev 3319) @@ -16,50 +16,67 @@ _defaulttimeout = None +import errno +import jarray +import string +import sys import threading import time import types -import jarray -import string -import sys +# Java.io classes import java.io.BufferedInputStream import java.io.BufferedOutputStream +# Java.io exceptions import java.io.InterruptedIOException +import java.io.IOException + +# Java.lang classes +import java.lang.String +# Java.lang exceptions import java.lang.Exception -import java.lang.String -import java.net.BindException -import java.net.ConnectException + +# Java.net classes import java.net.DatagramPacket import java.net.InetAddress import 
java.net.InetSocketAddress import java.net.Socket +# Java.net exceptions +import java.net.BindException +import java.net.ConnectException +import java.net.NoRouteToHostException +import java.net.PortUnreachableException +import java.net.ProtocolException +import java.net.SocketException import java.net.SocketTimeoutException import java.net.UnknownHostException + +# Java.nio classes import java.nio.ByteBuffer import java.nio.channels.DatagramChannel -import java.nio.channels.IllegalBlockingModeException import java.nio.channels.ServerSocketChannel import java.nio.channels.SocketChannel +# Java.nio exceptions +import java.nio.channels.AlreadyConnectedException +import java.nio.channels.AsynchronousCloseException +import java.nio.channels.CancelledKeyException +import java.nio.channels.ClosedByInterruptException +import java.nio.channels.ClosedChannelException +import java.nio.channels.ClosedSelectorException +import java.nio.channels.ConnectionPendingException +import java.nio.channels.IllegalBlockingModeException +import java.nio.channels.IllegalSelectorException +import java.nio.channels.NoConnectionPendingException +import java.nio.channels.NonReadableChannelException +import java.nio.channels.NonWritableChannelException +import java.nio.channels.NotYetBoundException +import java.nio.channels.NotYetConnectedException import java.nio.channels.UnresolvedAddressException +import java.nio.channels.UnsupportedAddressTypeException + import javax.net.ssl.SSLSocketFactory import org.python.core.PyFile -try: - import errno - ERRNO_EWOULDBLOCK = errno.EWOULDBLOCK - ERRNO_EACCES = errno.EACCES - ERRNO_ECONNREFUSED = errno.ECONNREFUSED - ERRNO_EINPROGRESS = errno.EINPROGRESS - ERRNO_EGETADDRINFOFAILED = errno.EGETADDRINFOFAILED -except ImportError: - # Support jython 2.1 - ERRNO_EWOULDBLOCK = 11 - ERRNO_EACCES = 13 - ERRNO_ECONNREFUSED = 111 - ERRNO_EINPROGRESS = 115 - ERRNO_EGETADDRINFOFAILED = 20001 - class error(Exception): pass class herror(error): pass class 
gaierror(error): pass @@ -71,23 +88,48 @@ # (<javaexception>, <circumstance>) : lambda: <code that raises the python equivalent> -(java.io.InterruptedIOException, ALL) : lambda exc: timeout('timed out'), -(java.net.BindException, ALL) : lambda exc: error(ERRNO_EACCES, 'Permission denied'), -(java.net.ConnectException, ALL) : lambda exc: error(ERRNO_ECONNREFUSED, 'Connection refused'), -(java.net.SocketTimeoutException, ALL) : lambda exc: timeout('timed out'), -(java.net.UnknownHostException, ALL) : lambda exc: gaierror(ERRNO_EGETADDRINFOFAILED, 'getaddrinfo failed'), -(java.nio.channels.UnresolvedAddressException, ALL) : lambda exc: gaierror(ERRNO_EGETADDRINFOFAILED, 'getaddrinfo failed'), +(java.io.IOException, ALL) : error(errno.ECONNRESET, 'Software caused connection abort'), +(java.io.InterruptedIOException, ALL) : timeout('timed out'), +(java.net.BindException, ALL) : error(errno.EADDRINUSE, 'Address already in use'), +(java.net.ConnectException, ALL) : error(errno.ECONNREFUSED, 'Connection refused'), +(java.net.NoRouteToHostException, ALL) : error(-1, 'Unmapped exception: java.net.NoRouteToHostException'), +(java.net.PortUnreachableException, ALL) : error(-1, 'Unmapped exception: java.net.PortUnreachableException'), +(java.net.ProtocolException, ALL) : error(-1, 'Unmapped exception: java.net.ProtocolException'), +(java.net.SocketException, ALL) : error(-1, 'Unmapped exception: java.net.SocketException'), +(java.net.SocketTimeoutException, ALL) : timeout('timed out'), +(java.net.UnknownHostException, ALL) : gaierror(errno.EGETADDRINFOFAILED, 'getaddrinfo failed'), + +(java.nio.channels.AlreadyConnectedException, ALL) : error(errno.EISCONN, 'Socket is already connected'), +(java.nio.channels.AsynchronousCloseException, ALL) : error(-1, 'Unmapped exception: java.nio.AsynchronousCloseException'), +(java.nio.channels.CancelledKeyException, ALL) : error(-1, 'Unmapped exception: java.nio.CancelledKeyException'), +(java.nio.channels.ClosedByInterruptException, ALL) : 
error(-1, 'Unmapped exception: java.nio.ClosedByInterruptException'), +(java.nio.channels.ClosedChannelException, ALL) : error(errno.EPIPE, 'Socket closed'), +(java.nio.channels.ClosedSelectorException, ALL) : error(-1, 'Unmapped exception: java.nio.ClosedSelectorException'), +(java.nio.channels.ConnectionPendingException, ALL) : error(-1, 'Unmapped exception: java.nio.ConnectionPendingException'), +(java.nio.channels.IllegalBlockingModeException, ALL) : error(-1, 'Unmapped exception: java.nio.IllegalBlockingModeException'), +(java.nio.channels.IllegalSelectorException, ALL) : error(-1, 'Unmapped exception: java.nio.IllegalSelectorException'), +(java.nio.channels.NoConnectionPendingException, ALL) : error(-1, 'Unmapped exception: java.nio.NoConnectionPendingException'), +(java.nio.channels.NonReadableChannelException, ALL) : error(-1, 'Unmapped exception: java.nio.NonReadableChannelException'), +(java.nio.channels.NonWritableChannelException, ALL) : error(-1, 'Unmapped exception: java.nio.NonWritableChannelException'), +(java.nio.channels.NotYetBoundException, ALL) : error(-1, 'Unmapped exception: java.nio.NotYetBoundException'), +(java.nio.channels.NotYetConnectedException, ALL) : error(-1, 'Unmapped exception: java.nio.NotYetConnectedException'), +(java.nio.channels.UnresolvedAddressException, ALL) : gaierror(errno.EGETADDRINFOFAILED, 'getaddrinfo failed'), +(java.nio.channels.UnsupportedAddressTypeException, ALL) : error(-1, 'Unmapped exception: java.nio.UnsupportedAddressTypeException'), + } def would_block_error(exc=None): - return error( (ERRNO_EWOULDBLOCK, 'The socket operation could not complete without blocking') ) + return error(errno.EWOULDBLOCK, 'The socket operation could not complete without blocking') def _map_exception(exc, circumstance=ALL): +# print "Mapping exception: %s" % exc try: - return _exception_map[(exc.__class__, circumstance)](exc) + mapped_exception = _exception_map[(exc.__class__, circumstance)] + mapped_exception.java_exception = exc 
+ return mapped_exception except KeyError: - return error(-1, 'Unmapped java exception: %s' % exc.toString()) + return error(-1, 'Unmapped java exception: <%s:%s>' % (exc.toString(), circumstance)) MODE_BLOCKING = 'block' MODE_NONBLOCKING = 'nonblock' @@ -95,6 +137,10 @@ _permitted_modes = (MODE_BLOCKING, MODE_NONBLOCKING, MODE_TIMEOUT) +SHUT_RD = 0 +SHUT_WR = 1 +SHUT_RDWR = 2 + class _nio_impl: timeout = None @@ -152,6 +198,12 @@ # close = close3 # close = close4 + def shutdownInput(self): + self.jsocket.shutdownInput() + + def shutdownOutput(self): + self.jsocket.shutdownOutput() + def getchannel(self): return self.jchannel @@ -184,9 +236,6 @@ def finish_connect(self): return self.jchannel.finishConnect() - def close(self): - _nio_impl.close(self) - class _server_socket_impl(_nio_impl): def __init__(self, host, port, backlog, reuse_addr): @@ -200,23 +249,17 @@ self.jsocket.bind(bindaddr, backlog) def accept(self): - try: - if self.mode in (MODE_BLOCKING, MODE_NONBLOCKING): - new_cli_chan = self.jchannel.accept() - if new_cli_chan != None: - return _client_socket_impl(new_cli_chan.socket()) - else: - return None + if self.mode in (MODE_BLOCKING, MODE_NONBLOCKING): + new_cli_chan = self.jchannel.accept() + if new_cli_chan != None: + return _client_socket_impl(new_cli_chan.socket()) else: - # In timeout mode now - new_cli_sock = self.jsocket.accept() - return _client_socket_impl(new_cli_sock) - except java.lang.Exception, jlx: - raise _map_exception(jlx) + return None + else: + # In timeout mode now + new_cli_sock = self.jsocket.accept() + return _client_socket_impl(new_cli_sock) - def close(self): - _nio_impl.close(self) - class _datagram_socket_impl(_nio_impl): def __init__(self, port=None, address=None, reuse_addr=0): @@ -248,6 +291,7 @@ 'getfqdn', 'gethostbyaddr', 'gethostbyname', 'gethostname', 'socket', 'getaddrinfo', 'getdefaulttimeout', 'setdefaulttimeout', 'has_ipv6', 'htons', 'htonl', 'ntohs', 'ntohl', + 'SHUT_RD', 'SHUT_WR', 'SHUT_RDWR', ] AF_INET = 2 @@ 
-343,7 +387,7 @@ floatvalue = float(value) except: raise TypeError('Socket timeout value must be a number or None') - if floatvalue < 0: + if floatvalue < 0.0: raise ValueError("Socket timeout value cannot be negative") if floatvalue < 0.000001: return 0.0 @@ -396,9 +440,6 @@ if not self.sock_impl: return None return self.sock_impl.getchannel() -# if hasattr(self.sock_impl, 'getchannel'): -# return self.sock_impl.getchannel() -# raise error('Operation not implemented on this JVM') fileno = getchannel @@ -421,7 +462,6 @@ local_addr = None server = 0 file_count = 0 - #reuse_addr = 1 reuse_addr = 0 def bind(self, addr): @@ -445,11 +485,6 @@ except java.lang.Exception, jlx: raise _map_exception(jlx) -# -# The following has information on a java.lang.NullPointerException problem I'm having -# -# http://developer.java.sun.com/developer/bugParade/bugs/4801882.html - def accept(self): "This signifies a server socket" try: @@ -495,7 +530,7 @@ if self.sock_impl.finish_connect(): self._setup(self.sock_impl) return 0 - return ERRNO_EINPROGRESS + return errno.EINPROGRESS def _setup(self, sock): self.sock_impl = sock @@ -506,7 +541,7 @@ def recv(self, n): try: - if not self.sock_impl: raise error('Socket not open') + if not self.sock_impl: raise error(errno.ENOTCONN, 'Socket is not connected') if self.sock_impl.jchannel.isConnectionPending(): self.sock_impl.jchannel.finishConnect() data = jarray.zeros(n, 'b') @@ -525,33 +560,41 @@ return self.recv(n), None def send(self, s): - if not self.sock_impl: raise error('Socket not open') - if self.sock_impl.jchannel.isConnectionPending(): - self.sock_impl.jchannel.finishConnect() - #n = len(s) - numwritten = self.sock_impl.write(s) - return numwritten + try: + if not self.sock_impl: raise error(errno.ENOTCONN, 'Socket is not connected') + if self.sock_impl.jchannel.isConnectionPending(): + self.sock_impl.jchannel.finishConnect() + numwritten = self.sock_impl.write(s) + return numwritten + except java.lang.Exception, jlx: + raise 
_map_exception(jlx) sendall = send def getsockname(self): - if not self.sock_impl: - host, port = self.local_addr or ("", 0) - host = java.net.InetAddress.getByName(host).getHostAddress() - else: - if self.server: - host = self.sock_impl.jsocket.getInetAddress().getHostAddress() + try: + if not self.sock_impl: + host, port = self.local_addr or ("", 0) + host = java.net.InetAddress.getByName(host).getHostAddress() else: - host = self.sock_impl.jsocket.getLocalAddress().getHostAddress() - port = self.sock_impl.jsocket.getLocalPort() - return (host, port) + if self.server: + host = self.sock_impl.jsocket.getInetAddress().getHostAddress() + else: + host = self.sock_impl.jsocket.getLocalAddress().getHostAddress() + port = self.sock_impl.jsocket.getLocalPort() + return (host, port) + except java.lang.Exception, jlx: + raise _map_exception(jlx) def getpeername(self): - assert self.sock_impl - assert not self.server - host = self.sock_impl.jsocket.getInetAddress().getHostAddress() - port = self.sock_impl.jsocket.getPort() - return (host, port) + try: + assert self.sock_impl + assert not self.server + host = self.sock_impl.jsocket.getInetAddress().getHostAddress() + port = self.sock_impl.jsocket.getPort() + return (host, port) + except java.lang.Exception, jlx: + raise _map_exception(jlx) def setsockopt(self, level, optname, value): if optname == SO_REUSEADDR: @@ -578,12 +621,11 @@ class FileWrapper: def __init__(self, socket, file): - self.socket = socket - self.sock = socket.sock_impl + self.socket = socket self.istream = socket.istream self.ostream = socket.ostream - self.file = file + self.file = file self.read = file.read self.readline = file.readline self.readlines = file.readlines @@ -597,51 +639,62 @@ self.socket.file_count += 1 def close(self): - if self.file.closed: + if self.closed: # Already closed return self.socket.file_count -= 1 - self.file.close() - self.closed = self.file.closed + # AMAK: 20070715: Cannot close the PyFile, because closing + # it causes the 
InputStream and OutputStream to be closed. + # This in turn causes the underlying socket to be closed. + # This was always true for java.net sockets + # And continues to be true for java.nio sockets + # http://bugs.sun.com/bugdatabase/view_bug.do?bug_id=4717638 +# self.file.close() + istream = self.istream + ostream = self.ostream + self.istream = None + self.ostream = None +# self.closed = self.file.closed + self.closed = 1 - if self.socket.file_count == 0 and self.socket.sock_impl == 0: + if self.socket.file_count == 0 and self.socket.sock_impl is None: # This is the last file Only close the socket and streams # if there are no outstanding files left. - if self.sock: - self.sock.close() - if self.istream: - self.istream.close() - if self.ostream: - self.ostream.close() + istream.close() + ostream.close() def shutdown(self, how): - assert how in (0, 1, 2) + assert how in (SHUT_RD, SHUT_WR, SHUT_RDWR) assert self.sock_impl - if how in (0, 2): + if how in (SHUT_RD, SHUT_RDWR): + self.sock_impl.shutdownInput() + if how in (SHUT_WR, SHUT_RDWR): + self.sock_impl.shutdownOutput() + + def close(self): + try: + if not self.sock_impl: + return + sock_impl = self.sock_impl + istream = self.istream + ostream = self.ostream + self.sock_impl = None self.istream = None - if how in (1, 2): self.ostream = None + # Only close the socket and streams if there are no + # outstanding files left. + if self.file_count == 0: + if istream: + istream.close() + if ostream: + ostream.close() + if sock_impl: + sock_impl.close() + except java.lang.Exception, jlx: + raise _map_exception(jlx) + - def close(self): - if not self.sock_impl: - return - sock = self.sock_impl - istream = self.istream - ostream = self.ostream - self.sock_impl = 0 - self.istream = 0 - self.ostream = 0 - # Only close the socket and streams if there are no - # outstanding files left. 
- if self.file_count == 0: - if istream: - istream.close() - if ostream: - ostream.close() - if sock: - sock.close() - class _udpsocket(_nonblocking_api_mixin): def __init__(self): @@ -650,50 +703,56 @@ self.reuse_addr = 0 def bind(self, addr): - assert not self.sock_impl - host, port = _unpack_address_tuple(addr) - host_address = java.net.InetAddress.getByName(host) - self.sock_impl = _datagram_socket_impl(port, host_address, reuse_addr = self.reuse_addr) - self._config() + try: + assert not self.sock_impl + host, port = _unpack_address_tuple(addr) + host_address = java.net.InetAddress.getByName(host) + self.sock_impl = _datagram_socket_impl(port, host_address, reuse_addr = self.reuse_addr) + self._config() + except java.lang.Exception, jlx: + raise _map_exception(jlx) + def _do_connect(self, addr): + try: + host, port = _unpack_address_tuple(addr) + assert not self.addr + self.addr = addr + if not self.sock_impl: + self.sock_impl = _datagram_socket_impl() + self._config() + self.sock_impl.connect(host, port) + except java.lang.Exception, jlx: + raise _map_exception(jlx) + def connect(self, addr): - host, port = _unpack_address_tuple(addr) - assert not self.addr - if not self.sock_impl: - self.sock_impl = _datagram_socket_impl() - self._config() - self.sock_impl.connect(host, port) - self.addr = addr # convert host to InetAddress instance? 
+ self._do_connect(addr) def connect_ex(self, addr): - host, port = _unpack_address_tuple(addr) - assert not self.addr - self.addr = addr - if not self.sock_impl: - self.sock_impl = _datagram_socket_impl() - self._config() - self.sock_impl.connect(host, port) - if self.sock_impl.finish_connect(): - return 0 - return ERRNO_EINPROGRESS + self._do_connect(addr) + if self.sock_impl.finish_connect(): + return 0 + return errno.EINPROGRESS def sendto(self, data, p1, p2=None): - if not p2: - flags, addr = 0, p1 - else: - flags, addr = 0, p2 - n = len(data) - if not self.sock_impl: - self.sock_impl = _datagram_socket_impl() - host, port = addr - bytes = java.lang.String(data).getBytes('iso-8859-1') - a = java.net.InetAddress.getByName(host) - packet = java.net.DatagramPacket(bytes, n, a, port) - self.sock_impl.send(packet) - return n + try: + if not p2: + flags, addr = 0, p1 + else: + flags, addr = 0, p2 + n = len(data) + if not self.sock_impl: + self.sock_impl = _datagram_socket_impl() + host, port = addr + bytes = java.lang.String(data).getBytes('iso-8859-1') + a = java.net.InetAddress.getByName(host) + packet = java.net.DatagramPacket(bytes, n, a, port) + self.sock_impl.send(packet) + return n + except java.lang.Exception, jlx: + raise _map_exception(jlx) def send(self, data): - assert self.addr + if not self.addr: raise error(errno.ENOTCONN, "Socket is not connected") return self.sendto(data, self.addr) def recvfrom(self, n): @@ -727,37 +786,52 @@ raise _map_exception(jlx) def getsockname(self): - assert self.sock_impl - host = self.sock_impl.jsocket.getLocalAddress().getHostName() - port = self.sock_impl.jsocket.getLocalPort() - return (host, port) + try: + assert self.sock_impl + host = self.sock_impl.jsocket.getLocalAddress().getHostName() + port = self.sock_impl.jsocket.getLocalPort() + return (host, port) + except java.lang.Exception, jlx: + raise _map_exception(jlx) def getpeername(self): - assert self.sock - host = 
self.sock_impl.jsocket.getInetAddress().getHostName() - port = self.sock_impl.jsocket.getPort() - return (host, port) + try: + assert self.sock + host = self.sock_impl.jsocket.getInetAddress().getHostName() + port = self.sock_impl.jsocket.getPort() + return (host, port) + except java.lang.Exception, jlx: + raise _map_exception(jlx) def __del__(self): self.close() def close(self): - if not self.sock_impl: - return - sock = self.sock_impl - self.sock_impl = None - sock.close() + try: + if not self.sock_impl: + return + sock = self.sock_impl + self.sock_impl = None + sock.close() + except java.lang.Exception, jlx: + raise _map_exception(jlx) def setsockopt(self, level, optname, value): - if optname == SO_REUSEADDR: - self.reuse_addr = value -# self.sock._setreuseaddress(value) + try: + if optname == SO_REUSEADDR: + self.reuse_addr = value +# self.sock._setreuseaddress(value) + except java.lang.Exception, jlx: + raise _map_exception(jlx) def getsockopt(self, level, optname): - if optname == SO_REUSEADDR: - return self.sock_impl._getreuseaddress() - else: - return None + try: + if optname == SO_REUSEADDR: + return self.sock_impl._getreuseaddress() + else: + return None + except java.lang.Exception, jlx: + raise _map_exception(jlx) SocketType = _tcpsocket SocketTypes = [_tcpsocket, _udpsocket] Modified: trunk/jython/Lib/test/test_socket.py =================================================================== --- trunk/jython/Lib/test/test_socket.py 2007-07-16 02:44:37 UTC (rev 3318) +++ trunk/jython/Lib/test/test_socket.py 2007-07-16 19:13:26 UTC (rev 3319) @@ -1,1059 +1,1121 @@ -from __future__ import nested_scopes - -""" -AMAK: 20050515: This module is the test_socket.py from cpython 2.4, ported to jython. 
-""" - -import unittest -#from test import test_support - -import errno -import socket -import select -import time -import thread, threading -import Queue -import sys -from weakref import proxy - -PORT = 50007 -HOST = 'localhost' -MSG = 'Michael Gilfix was here\n' -EIGHT_BIT_MSG = 'Bh\xed Al\xe1in \xd3 Cinn\xe9ide anseo\n' - -try: - True -except NameError: - True, False = 1, 0 - -class SocketTCPTest(unittest.TestCase): - - def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.serv.bind((HOST, PORT)) - self.serv.listen(1) - - def tearDown(self): - self.serv.close() - self.serv = None - -class SocketUDPTest(unittest.TestCase): - - def setUp(self): - self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.serv.bind((HOST, PORT)) - - def tearDown(self): - self.serv.close() - self.serv = None - -class ThreadableTest: - """Threadable Test class - - The ThreadableTest class makes it easy to create a threaded - client/server pair from an existing unit test. To create a - new threaded class from an existing unit test, use multiple - inheritance: - - class NewClass (OldClass, ThreadableTest): - pass - - This class defines two new fixture functions with obvious - purposes for overriding: - - clientSetUp () - clientTearDown () - - Any new test functions within the class must then define - tests in pairs, where the test name is preceeded with a - '_' to indicate the client portion of the test. Ex: - - def testFoo(self): - # Server portion - - def _testFoo(self): - # Client portion - - Any exceptions raised by the clients during their tests - are caught and transferred to the main thread to alert - the testing framework. 
- - Note, the server setup function cannot call any blocking - functions that rely on the client thread during setup, - unless serverExplicityReady() is called just before - the blocking call (such as in setting up a client/server - connection and performing the accept() in setUp(). - """ - - def __init__(self): - # Swap the true setup function - self.__setUp = self.setUp - self.__tearDown = self.tearDown - self.setUp = self._setUp - self.tearDown = self._tearDown - - def serverExplicitReady(self): - """This method allows the server to explicitly indicate that - it wants the client thread to proceed. This is useful if the - server is about to execute a blocking routine that is - dependent upon the client thread during its setup routine.""" - self.server_ready.set() - - def _setUp(self): - self.server_ready = threading.Event() - self.client_ready = threading.Event() - self.done = threading.Event() - self.queue = Queue.Queue(1) - - # Do some munging to start the client test. - methodname = self.id() - i = methodname.rfind('.') - methodname = methodname[i+1:] - self.test_method_name = methodname - test_method = getattr(self, '_' + methodname) - self.client_thread = thread.start_new_thread( - self.clientRun, (test_method,)) - - self.__setUp() - if not self.server_ready.isSet(): - self.server_ready.set() - self.client_ready.wait() - - def _tearDown(self): - self.__tearDown() - self.done.wait() - - if not self.queue.empty(): - msg = self.queue.get() - self.fail(msg) - - def clientRun(self, test_func): - self.server_ready.wait() - self.client_ready.set() - self.clientSetUp() - if not callable(test_func): - raise TypeError, "test_func must be a callable function" - try: - test_func() - except Exception, strerror: - self.queue.put(strerror) - self.clientTearDown() - - def clientSetUp(self): - raise NotImplementedError, "clientSetUp must be implemented." 
- - def clientTearDown(self): - self.done.set() - if sys.platform[:4] != 'java': - # This causes the whole process to exit on jython - # Probably related to problems with daemon status of threads - thread.exit() - -class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest): - - def __init__(self, methodName='runTest'): - SocketTCPTest.__init__(self, methodName=methodName) - ThreadableTest.__init__(self) - - def clientSetUp(self): - self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - - def clientTearDown(self): - self.cli.close() - self.cli = None - ThreadableTest.clientTearDown(self) - -class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest): - - def __init__(self, methodName='runTest'): - SocketUDPTest.__init__(self, methodName=methodName) - ThreadableTest.__init__(self) - - def clientSetUp(self): - self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - -class SocketConnectedTest(ThreadedTCPSocketTest): - - def __init__(self, methodName='runTest'): - ThreadedTCPSocketTest.__init__(self, methodName=methodName) - - def setUp(self): - ThreadedTCPSocketTest.setUp(self) - # Indicate explicitly we're ready for the client thread to - # proceed and then perform the blocking call to accept - self.serverExplicitReady() - conn, addr = self.serv.accept() - self.cli_conn = conn - - def tearDown(self): - self.cli_conn.close() - self.cli_conn = None - ThreadedTCPSocketTest.tearDown(self) - - def clientSetUp(self): - ThreadedTCPSocketTest.clientSetUp(self) - self.cli.connect((HOST, PORT)) - self.serv_conn = self.cli - - def clientTearDown(self): - self.serv_conn.close() - self.serv_conn = None - ThreadedTCPSocketTest.clientTearDown(self) - -class SocketPairTest(unittest.TestCase, ThreadableTest): - - def __init__(self, methodName='runTest'): - unittest.TestCase.__init__(self, methodName=methodName) - ThreadableTest.__init__(self) - - def setUp(self): - self.serv, self.cli = socket.socketpair() - - def tearDown(self): - self.serv.close() - self.serv = None - - 
def clientSetUp(self): - pass - - def clientTearDown(self): - self.cli.close() - self.cli = None - ThreadableTest.clientTearDown(self) - - -####################################################################### -## Begin Tests - -class GeneralModuleTests(unittest.TestCase): - - def test_weakref(self): - if sys.platform[:4] == 'java': return - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - p = proxy(s) - self.assertEqual(p.fileno(), s.fileno()) - s.close() - s = None - try: - p.fileno() - except ReferenceError: - pass - else: - self.fail('Socket proxy still exists') - - def testSocketError(self): - # Testing socket module exceptions - def raise_error(*args, **kwargs): - raise socket.error - def raise_herror(*args, **kwargs): - raise socket.herror - def raise_gaierror(*args, **kwargs): - raise socket.gaierror - self.failUnlessRaises(socket.error, raise_error, - "Error raising socket exception.") - self.failUnlessRaises(socket.error, raise_herror, - "Error raising socket exception.") - self.failUnlessRaises(socket.error, raise_gaierror, - "Error raising socket exception.") - - def testCrucialConstants(self): - # Testing for mission critical constants - socket.AF_INET - socket.SOCK_STREAM - socket.SOCK_DGRAM - socket.SOCK_RAW - socket.SOCK_RDM - socket.SOCK_SEQPACKET - socket.SOL_SOCKET - socket.SO_REUSEADDR - - def testHostnameRes(self): - # Testing hostname resolution mechanisms - hostname = socket.gethostname() - try: - ip = socket.gethostbyname(hostname) - except socket.error: - # Probably name lookup wasn't set up right; skip this test - self.fail("Probably name lookup wasn't set up right; skip testHostnameRes.gethostbyname") - return - self.assert_(ip.find('.') >= 0, "Error resolving host to ip.") - try: - hname, aliases, ipaddrs = socket.gethostbyaddr(ip) - except socket.error: - # Probably a similar problem as above; skip this test - self.fail("Probably name lookup wasn't set up right; skip testHostnameRes.gethostbyaddr") - return - all_host_names = 
[hostname, hname] + aliases - fqhn = socket.getfqdn() - if not fqhn in all_host_names: - self.fail("Error testing host resolution mechanisms.") - - def testRefCountGetNameInfo(self): - # Testing reference count for getnameinfo - import sys - if hasattr(sys, "getrefcount"): - try: - # On some versions, this loses a reference - orig = sys.getrefcount(__name__) - socket.getnameinfo(__name__,0) - except SystemError: - if sys.getrefcount(__name__) <> orig: - self.fail("socket.getnameinfo loses a reference") - - def testInterpreterCrash(self): - if sys.platform[:4] == 'java': return - # Making sure getnameinfo doesn't crash the interpreter - try: - # On some versions, this crashes the interpreter. - socket.getnameinfo(('x', 0, 0, 0), 0) - except socket.error: - pass - -# Need to implement binary AND for ints and longs - - def testNtoH(self): - if sys.platform[:4] == 'java': return # problems with int & long - # This just checks that htons etc. are their own inverse, - # when looking at the lower 16 or 32 bits. - sizes = {socket.htonl: 32, socket.ntohl: 32, - socket.htons: 16, socket.ntohs: 16} - for func, size in sizes.items(): - mask = (1L<<size) - 1 - for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210): - self.assertEqual(i & mask, func(func(i&mask)) & mask) - - swapped = func(mask) - self.assertEqual(swapped & mask, mask) - self.assertRaises(OverflowError, func, 1L<<34) - - def testGetServBy(self): - if sys.platform[:4] == 'java': return # not implemented on java - eq = self.assertEqual - # Find one service that exists, then check all the related interfaces. - # I've ordered this by protocols that have both a tcp and udp - # protocol, at least for modern Linuxes. 
- if sys.platform in ('linux2', 'freebsd4', 'freebsd5', 'freebsd6', - 'darwin'): - # avoid the 'echo' service on this platform, as there is an - # assumption breaking non-standard port/protocol entry - services = ('daytime', 'qotd', 'domain') - else: - services = ('echo', 'daytime', 'domain') - for service in services: - try: - port = socket.getservbyname(service, 'tcp') - break - except socket.error: - pass - else: - raise socket.error - # Try same call with optional protocol omitted - port2 = socket.getservbyname(service) - eq(port, port2) - # Try udp, but don't barf it it doesn't exist - try: - udpport = socket.getservbyname(service, 'udp') - except socket.error: - udpport = None - else: - eq(udpport, port) - # Now make sure the lookup by port returns the same service name - eq(socket.getservbyport(port2), service) - eq(socket.getservbyport(port, 'tcp'), service) - if udpport is not None: - eq(socket.getservbyport(udpport, 'udp'), service) - - def testDefaultTimeout(self): - # Testing default timeout - # The default timeout should initially be None - self.assertEqual(socket.getdefaulttimeout(), None) - s = socket.socket() - self.assertEqual(s.gettimeout(), None) - s.close() - - # Set the default timeout to 10, and see if it propagates - socket.setdefaulttimeout(10) - self.assertEqual(socket.getdefaulttimeout(), 10) - s = socket.socket() - self.assertEqual(s.gettimeout(), 10) - s.close() - - # Reset the default timeout to None, and see if it propagates - socket.setdefaulttimeout(None) - self.assertEqual(socket.getdefaulttimeout(), None) - s = socket.socket() - self.assertEqual(s.gettimeout(), None) - s.close() - - # Check that setting it to an invalid value raises ValueError - self.assertRaises(ValueError, socket.setdefaulttimeout, -1) - - # Check that setting it to an invalid type raises TypeError - self.assertRaises(TypeError, socket.setdefaulttimeout, "spam") - - def testIPv4toString(self): - if not hasattr(socket, 'inet_pton'): - return # No inet_pton() on 
this platform - from socket import inet_aton as f, inet_pton, AF_INET - g = lambda a: inet_pton(AF_INET, a) - - self.assertEquals('\x00\x00\x00\x00', f('0.0.0.0')) - self.assertEquals('\xff\x00\xff\x00', f('255.0.255.0')) - self.assertEquals('\xaa\xaa\xaa\xaa', f('170.170.170.170')) - self.assertEquals('\x01\x02\x03\x04', f('1.2.3.4')) - - self.assertEquals('\x00\x00\x00\x00', g('0.0.0.0')) - self.assertEquals('\xff\x00\xff\x00', g('255.0.255.0')) - self.assertEquals('\xaa\xaa\xaa\xaa', g('170.170.170.170')) - - def testIPv6toString(self): - if not hasattr(socket, 'inet_pton'): - return # No inet_pton() on this platform - try: - from socket import inet_pton, AF_INET6, has_ipv6 - if not has_ipv6: - return - except ImportError: - return - f = lambda a: inet_pton(AF_INET6, a) - - self.assertEquals('\x00' * 16, f('::')) - self.assertEquals('\x00' * 16, f('0::0')) - self.assertEquals('\x00\x01' + '\x00' * 14, f('1::')) - self.assertEquals( - '\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae', - f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae') - ) - - def testStringToIPv4(self): - if not hasattr(socket, 'inet_ntop'): - return # No inet_ntop() on this platform - from socket import inet_ntoa as f, inet_ntop, AF_INET - g = lambda a: inet_ntop(AF_INET, a) - - self.assertEquals('1.0.1.0', f('\x01\x00\x01\x00')) - self.assertEquals('170.85.170.85', f('\xaa\x55\xaa\x55')) - self.assertEquals('255.255.255.255', f('\xff\xff\xff\xff')) - self.assertEquals('1.2.3.4', f('\x01\x02\x03\x04')) - - self.assertEquals('1.0.1.0', g('\x01\x00\x01\x00')) - self.assertEquals('170.85.170.85', g('\xaa\x55\xaa\x55')) - self.assertEquals('255.255.255.255', g('\xff\xff\xff\xff')) - - def testStringToIPv6(self): - if not hasattr(socket, 'inet_ntop'): - return # No inet_ntop() on this platform - try: - from socket import inet_ntop, AF_INET6, has_ipv6 - if not has_ipv6: - return - except ImportError: - return - f = lambda a: inet_ntop(AF_INET6, a) - - self.assertEquals('::', f('\x00' * 16)) 
- self.assertEquals('::1', f('\x00' * 15 + '\x01')) - self.assertEquals( - 'aef:b01:506:1001:ffff:9997:55:170', - f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70') - ) - - # XXX The following don't test module-level functionality... - - def testSockName(self): - # Testing getsockname() - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(("0.0.0.0", PORT+1)) - name = sock.getsockname() - self.assertEqual(name, ("0.0.0.0", PORT+1)) - - def testGetSockOpt(self): - # Testing getsockopt() - # We know a socket should start without reuse==0 - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) - self.failIf(reuse != 0, "initial mode is reuse") - - def testSetSockOpt(self): - # Testing setsockopt() - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) - self.failIf(reuse == 0, "failed to set reuse mode") - - def testSendAfterClose(self): - # testing send() after close() with timeout - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(1) - sock.close() - self.assertRaises(socket.error, sock.send, "spam") - -class BasicTCPTest(SocketConnectedTest): - - def __init__(self, methodName='runTest'): - SocketConnectedTest.__init__(self, methodName=methodName) - - def testRecv(self): - # Testing large receive over TCP - msg = self.cli_conn.recv(1024) - self.assertEqual(msg, MSG) - - def _testRecv(self): - self.serv_conn.send(MSG) - - def testOverFlowRecv(self): - # Testing receive in chunks over TCP - seg1 = self.cli_conn.recv(len(MSG) - 3) - seg2 = self.cli_conn.recv(1024) - msg = seg1 + seg2 - self.assertEqual(msg, MSG) - - def _testOverFlowRecv(self): - self.serv_conn.send(MSG) - - def testRecvFrom(self): - # Testing large recvfrom() over TCP - msg, addr = self.cli_conn.recvfrom(1024) - self.assertEqual(msg, MSG) - - 
def _testRecvFrom(self): - self.serv_conn.send(MSG) - - def testOverFlowRecvFrom(self): - # Testing recvfrom() in chunks over TCP - seg1, addr = self.cli_conn.recvfrom(len(MSG)-3) - seg2, addr = self.cli_conn.recvfrom(1024) - msg = seg1 + seg2 - self.assertEqual(msg, MSG) - - def _testOverFlowRecvFrom(self): - self.serv_conn.send(MSG) - - def testSendAll(self): - # Testing sendall() with a 2048 byte string over TCP - msg = '' - while 1: - read = self.cli_conn.recv(1024) - if not read: - break - msg += read - self.assertEqual(msg, 'f' * 2048) - - def _testSendAll(self): - big_chunk = 'f' * 2048 - self.serv_conn.sendall(big_chunk) - - def testFromFd(self): - # Testing fromfd() - if not hasattr(socket, "fromfd"): - return # On Windows, this doesn't exist - fd = self.cli_conn.fileno() - sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) - msg = sock.recv(1024) - self.assertEqual(msg, MSG) - - def _testFromFd(self): - self.serv_conn.send(MSG) - - def testShutdown(self): - # Testing shutdown() - msg = self.cli_conn.recv(1024) - self.assertEqual(msg, MSG) - - def _testShutdown(self): - self.serv_conn.send(MSG) - self.serv_conn.shutdown(2) - -class BasicUDPTest(ThreadedUDPSocketTest): - - def __init__(self, methodName='runTest'): - ThreadedUDPSocketTest.__init__(self, methodName=methodName) - - def testSendtoAndRecv(self): - # Testing sendto() and Recv() over UDP - msg = self.serv.recv(len(MSG)) - self.assertEqual(msg, MSG) - - def _testSendtoAndRecv(self): - self.cli.sendto(MSG, 0, (HOST, PORT)) - - def testRecvFrom(self): - # Testing recvfrom() over UDP - msg, addr = self.serv.recvfrom(len(MSG)) - self.assertEqual(msg, MSG) - - def _testRecvFrom(self): - self.cli.sendto(MSG, 0, (HOST, PORT)) - - def testSendtoEightBitSafe(self): - # This test is necessary because java only supports signed bytes - msg = self.serv.recv(len(EIGHT_BIT_MSG)) - self.assertEqual(msg, EIGHT_BIT_MSG) - - def _testSendtoEightBitSafe(self): - self.cli.sendto(EIGHT_BIT_MSG, 0, (HOST, 
PORT)) - -class BasicSocketPairTest(SocketPairTest): - - def __init__(self, methodName='runTest'): - SocketPairTest.__init__(self, methodName=methodName) - - def testRecv(self): - msg = self.serv.recv(1024) - self.assertEqual(msg, MSG) - - def _testRecv(self): - self.cli.send(MSG) - - def testSend(self): - self.serv.send(MSG) - - def _testSend(self): - msg = self.cli.recv(1024) - self.assertEqual(msg, MSG) - -class NonBlockingTCPTests(ThreadedTCPSocketTest): - - def __init__(self, methodName='runTest'): - ThreadedTCPSocketTest.__init__(self, methodName=methodName) - - def testSetBlocking(self): - # Testing whether set blocking works - self.serv.setblocking(0) - start = time.time() - try: - self.serv.accept() - except socket.error: - pass - end = time.time() - self.assert_((end - start) < 1.0, "Error setting non-blocking mode.") - - def _testSetBlocking(self): - pass - - # - # AMAK: 20070307 - # Split testAccept into two separate tests - # 1. A test for non-blocking accept when there is NO connection pending - # 2. A test for non-blocking accept when there is A connection pending - # I think that perhaps the only reason the original combined test passes - # on cpython is because of thread timing and sychronization parameters - # of that platform. 
- # - - def testAcceptNoConnection(self): - # Testing non-blocking accept returns immediately when no connection - self.serv.setblocking(0) - try: - conn, addr = self.serv.accept() - except socket.error: - pass - else: - self.fail("Error trying to do non-blocking accept.") - - def _testAcceptNoConnection(self): - # Client side does nothing - pass - - def testAcceptConnection(self): - # Testing non-blocking accept works when connection present - self.serv.setblocking(0) - read, write, err = select.select([self.serv], [], []) - if self.serv in read: - conn, addr = self.serv.accept() - else: - self.fail("Error trying to do accept after select: server socket was not in 'read'able list") - - def _testAcceptConnection(self): - # Make a connection to the server - self.cli.connect((HOST, PORT)) - - # - # AMAK: 20070311 - # Introduced a new test for non-blocking connect - # Renamed old testConnect to testBlockingConnect - # - - def testBlockingConnect(self): - # Testing blocking connect - conn, addr = self.serv.accept() - - def _testBlockingConnect(self): - # Testing blocking connect - self.cli.settimeout(10) - self.cli.connect((HOST, PORT)) - - def testNonBlockingConnect(self): - # Testing non-blocking connect - conn, addr = self.serv.accept() - - def _testNonBlockingConnect(self): - # Testing non-blocking connect - self.cli.setblocking(0) - result = self.cli.connect_ex((HOST, PORT)) - rfds, wfds, xfds = select.select([], [self.cli], []) - self.failUnless(self.cli in wfds) - try: - self.cli.send(MSG) - except socket.error: - self.fail("Sending on connected socket should not have raised socket.error") - - # - # AMAK: 20070518 - # Introduced a new test for connect with bind to specific local address - # - - def testConnectWithLocalBind(self): - # Test blocking connect - conn, addr = self.serv.accept() - - def _testConnectWithLocalBind(self): - # Testing blocking connect with local bind - self.cli.settimeout(1) - self.cli.bind( (HOST, PORT-1) ) - self.cli.connect((HOST, 
PORT)) - bound_host, bound_port = self.cli.getsockname() - self.failUnlessEqual(bound_port, PORT-1) - - def testRecvData(self): - # Testing non-blocking recv - conn, addr = self.serv.accept() - conn.setblocking(0) - rfds, wfds, xfds = select.select([conn], [], []) - if conn in rfds: - msg = conn.recv(len(MSG)) - self.assertEqual(msg, MSG) - else: - self.fail("Non-blocking socket with data should been in read list.") - - def _testRecvData(self): - self.cli.connect((HOST, PORT)) - self.cli.send(MSG) - - def testRecvNoData(self): - # Testing non-blocking recv - conn, addr = self.serv.accept() - conn.setblocking(0) - try: - msg = conn.recv(len(MSG)) - except socket.error: - pass - else: - self.fail("Non-blocking recv of no data should have raised socket.error.") - - def _testRecvNoData(self): - self.cli.connect((HOST, PORT)) - time.sleep(0.1) - -class NonBlockingUDPTests(ThreadedUDPSocketTest): pass - -# -# TODO: Write some non-blocking UDP tests -# - -class FileObjectClassTestCase(SocketConnectedTest): - - bufsize = -1 # Use default buffer size - - def __init__(self, methodName='runTest'): - SocketConnectedTest.__init__(self, methodName=methodName) - - def setUp(self): - SocketConnectedTest.setUp(self) - self.serv_file = self.cli_conn.makefile('rb', self.bufsize) - - def tearDown(self): - self.serv_file.close() - self.assert_(self.serv_file.closed) - self.serv_file = None - SocketConnectedTest.tearDown(self) - - def clientSetUp(self): - SocketConnectedTest.clientSetUp(self) - self.cli_file = self.serv_conn.makefile('wb') - - def clientTearDown(self): - self.cli_file.close() - self.assert_(self.cli_file.closed) - self.cli_file = None - SocketConnectedTest.clientTearDown(self) - - def testSmallRead(self): - # Performing small file read test - first_seg = self.serv_file.read(len(MSG)-3) - second_seg = self.serv_file.read(3) - msg = first_seg + second_seg - self.assertEqual(msg, MSG) - - def _testSmallRead(self): - self.cli_file.write(MSG) - self.cli_file.flush() - - def 
testFullRead(self): - # read until EOF - msg = self.serv_file.read() - self.assertEqual(msg, MSG) - - def _testFullRead(self): - self.cli_file.write(MSG) - self.cli_file.close() - - def testUnbufferedRead(self): - # Performing unbuffered file read test - buf = '' - while 1: - char = self.serv_file.read(1) - if not char: - break - buf += char - self.assertEqual(buf, MSG) - - def _testUnbufferedRead(self): - self.cli_file.write(MSG) - self.cli_file.flush() - - def testReadline(self): - # Performing file readline test - line = self.serv_file.readline() - self.assertEqual(line, MSG) - - def _testReadline(self): - self.cli_file.write(MSG) - self.cli_file.flush() - - def testClosedAttr(self): - self.assert_(not self.serv_file.closed) - - def _testClosedAttr(self): - self.assert_(not self.cli_file.closed) - -class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): - - """Repeat the tests from FileObjectClassTestCase with bufsize==0. - - In this case (and in this case only), it should be possible to - create a file object, read a line from it, create another file - object, read another line from it, without loss of data in the - first file object's buffer. Note that httplib relies on this - when reading multiple requests from the same socket.""" - - bufsize = 0 # Use unbuffered mode - - def testUnbufferedReadline(self): - # Read a line, create a new file object, read another line with it - line = self.serv_file.readline() # first line - self.assertEqual(line, "A. " + MSG) # first line - self.serv_file = self.cli_conn.makefile('rb', 0) - line = self.serv_file.readline() # second line - self.assertEqual(line, "B. " + MSG) # second line - - def _testUnbufferedReadline(self): - self.cli_file.write("A. " + MSG) - self.cli_file.write("B. 
" + MSG) - self.cli_file.flush() - -class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase): - - bufsize = 1 # Default-buffered for reading; line-buffered for writing - - -class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase): - - bufsize = 2 # Exercise the buffering code - -class TCPTimeoutTest(SocketTCPTest): - - def testTCPTimeout(self): - def raise_timeout(*args, **kwargs): - self.serv.settimeout(1.0) - self.serv.accept() - self.failUnlessRaises(socket.timeout, raise_timeout, - "Error generating a timeout exception (TCP)") - - def testTimeoutZero(self): - ok = False - try: - self.serv.settimeout(0.0) - foo = self.serv.accept() - except socket.timeout: - self.fail("caught timeout instead of error (TCP)") - except socket.error: - ok = True - except Exception, x: - self.fail("caught unexpected exception (TCP): %s" % str(x)) - if not ok: - self.fail("accept() returned success when we did not expect it") - -class TCPClientTimeoutTest(ThreadedTCPSocketTest): - - def testTCPClientTimeout(self): - pass # i.e. 
do not accept - - def _testTCPClientTimeout(self): - try: - self.cli.settimeout(0.1) - self.cli.connect( (HOST, PORT) ) - except socket.timeout, st: - pass - except Exception, x: - self.fail("Client socket timeout should have raised socket.timeout, not %s" % str(x)) - else: - self.fail("Client socket timeout should have raised socket.timeout") - -# -# AMAK: 20070307 -# Corrected the superclass of UDPTimeoutTest -# - -class UDPTimeoutTest(SocketUDPTest): - - def testUDPTimeout(self): - def raise_timeout(*args, **kwargs): - self.serv.settimeout(1.0) - self.serv.recv(1024) - self.failUnlessRaises(socket.timeout, raise_timeout, - "Error generating a timeout exception (UDP)") - - def testTimeoutZero(self): - ok = False - try: - self.serv.settimeout(0.0) - foo = self.serv.recv(1024) - except socket.timeout: - self.fail("caught timeout instead of error (UDP)") - except socket.error: - ok = True - except Exception, x: - self.fail("caught unexpected exception (UDP): %s" % str(x)) - if not ok: - self.fail("recv() returned success when we did not expect it") - -class TestExceptions(unittest.TestCase): - - def testExceptionTree(self): - self.assert_(issubclass(socket.error, Exception)) - self.assert_(issubclass(socket.herror, socket.error)) - self.assert_(issubclass(socket.gaierror, socket.error)) - self.assert_(issubclass(socket.timeout, socket.error)) - -class TestJythonExceptions(unittest.TestCase): - - def testHostNotFound(self): - try: - socket.gethostbyname("doesnotexist") - except socket.gaierror, gaix: - self.failUnlessEqual(gaix[0], errno.EGETADDRINFOFAILED) - except Exception, x: - self.fail("Get host name for non-existent host raised wrong exception: %s" % x) - - def testConnectionRefused(self): - try: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # This port should not be open at this time - s.connect( (HOST, PORT) ) - except socket.error, se: - self.failUnlessEqual(se[0], errno.ECONNREFUSED) - except Exception, x: - self.fail("Connection to 
non-existent host/port raised wrong exception: %s" % x) - else: - self.fail("Socket (%s,%s) should not have been listening at this time" % (HOST, PORT)) - - def testBindException(self): - # First bind to the target port - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind( (HOST, PORT) ) - s.listen() - try: - try: - # And then try to bind again - t = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - t.bind( (HOST, PORT) ) - t.listen() - except socket.error, se: - self.failUnlessEqual(se[0], errno.EACCES) - except Exception, x: - self.fail("Binding to already bound host/port raised wrong exception: %s" % x) - else: - self.fail("Binding to already bound host/port should have raised exception") - finally: - s.close() - - def testUnresolvedAddress(self): - try: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.connect( ('non.existent.server', PORT) ) - except socket.gaierror, gaix: - self.failUnlessEqual(gaix[0], errno.EGETADDRINFOFAILED) - except Exception, x: - self.fail("Get host name for non-existent host raised wrong exception: %s" % x) - else: - self.fail("Get host name for non-existent host should have raised exception") - -class TestAddressParameters: - - def testBindNonTupleEndpointRaisesTypeError(self): - try: - self.socket.bind(HOST, PORT) - except TypeError: - pass - else: - self.fail("Illegal non-tuple bind address did not raise TypeError") - - def testConnectNonTupleEndpointRaisesTypeError(self): - try: - self.socket.connect(HOST, PORT) - except TypeError: - pass - else: - self.fail("Illegal non-tuple connect address did not raise TypeError") - - def testConnectExNonTupleEndpointRaisesTypeError(self): - try: - self.socket.connect_ex(HOST, PORT) - except TypeError: - pass - else: - self.fail("Illegal non-tuple connect address did not raise TypeError") - -class TestTCPAddressParameters(unittest.TestCase, TestAddressParameters): - - def setUp(self): - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - -class 
TestUDPAddressParameters(unittest.TestCase, TestAddressParameters): - - def setUp(self): - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - -def test_main(): - tests = [ - GeneralModuleTests, - BasicTCPTest, - TCPTimeoutTest, - TCPClientTimeoutTest, - TestExceptions, - TestTCPAddressParameters, - TestUDPAddressParameters, - BasicUDPTest, - UDPTimeoutTest, - NonBlockingTCPTests, - NonBlockingUDPTests, - FileObjectClassTestCase, - UnbufferedFileObjectClassTestCase, - LineBufferedFileObjectClassTestCase, - SmallBufferedFileObjectCla... [truncated message content] |
From: <am...@us...> - 2007-07-19 18:27:26
|
Revision: 3332 http://svn.sourceforge.net/jython/?rev=3332&view=rev Author: amak Date: 2007-07-19 11:27:23 -0700 (Thu, 19 Jul 2007) Log Message: ----------- . Added a socket.getblocking() method, and unittest for same . Tidied up select exception handling . Removed ERRNO_ constants from select module; jython 2.1 users can create their own errno module. . Some small speed improvements to select.select . Added a cpython compatible select function to select module Modified Paths: -------------- trunk/jython/Lib/select.py trunk/jython/Lib/socket.py trunk/jython/Lib/test/test_select.py trunk/jython/Lib/test/test_socket.py Modified: trunk/jython/Lib/select.py =================================================================== --- trunk/jython/Lib/select.py 2007-07-19 07:19:30 UTC (rev 3331) +++ trunk/jython/Lib/select.py 2007-07-19 18:27:23 UTC (rev 3332) @@ -9,15 +9,7 @@ import socket -try: - import errno - ERRNO_EINVAL = errno.EINVAL - ERRNO_ENOTSOCK = errno.ENOTSOCK - ERRNO_ESOCKISBLOCKING = errno.ESOCKISBLOCKING -except ImportError: - ERRNO_EINVAL = 22 - ERRNO_ENOTSOCK = 88 - ERRNO_ESOCKISBLOCKING = 20000 +import errno class error(Exception): pass @@ -27,14 +19,16 @@ # (<javaexception>, <circumstance>) : lambda: <code that raises the python equivalent> -(java.nio.channels.IllegalBlockingModeException, ALL) : lambda exc: error(ERRNO_ESOCKISBLOCKING, 'socket must be in non-blocking mode'), +(java.nio.channels.IllegalBlockingModeException, ALL) : error(errno.ESOCKISBLOCKING, 'socket must be in non-blocking mode'), } def _map_exception(exc, circumstance=ALL): try: - return _exception_map[(exc.__class__, circumstance)](exc) + mapped_exception = _exception_map[(exc.__class__, circumstance)] + mapped_exception.java_exception = exc + return mapped_exception except KeyError: - return error(-1, 'Unmapped java exception: %s' % exc.toString()) + return error(-1, 'Unmapped java exception: <%s:%s>' % (exc.toString(), circumstance)) POLLIN = 1 POLLOUT = 2 @@ -48,6 +42,17 @@ POLLHUP 
= 16 POLLNVAL = 32 +def _getselectable(selectable_object): + for method in ['getchannel', 'fileno']: + try: + channel = getattr(selectable_object, method)() + if channel and not isinstance(channel, java.nio.channels.SelectableChannel): + raise TypeError("Object '%s' is not watchable" % selectable_object, errno.ENOTSOCK) + return channel + except: + pass + raise TypeError("Object '%s' is not watchable" % selectable_object, errno.ENOTSOCK) + class poll: def __init__(self): @@ -55,15 +60,6 @@ self.chanmap = {} self.unconnected_sockets = [] - def _getselectable(self, socket_object): - for st in socket.SocketTypes: - if isinstance(socket_object, st): - try: - return socket_object.getchannel() - except: - return None - raise error("Object '%s' is not watchable" % socket_object, ERRNO_ENOTSOCK) - def _register_channel(self, socket_object, channel, mask): jmask = 0 if mask & POLLIN: @@ -82,7 +78,7 @@ def _check_unconnected_sockets(self): temp_list = [] for socket_object, mask in self.unconnected_sockets: - channel = self._getselectable(socket_object) + channel = _getselectable(socket_object) if channel is not None: self._register_channel(socket_object, channel, mask) else: @@ -91,7 +87,7 @@ def register(self, socket_object, mask = POLLIN|POLLOUT|POLLPRI): try: - channel = self._getselectable(socket_object) + channel = _getselectable(socket_object) if channel is None: # The socket is not yet connected, and thus has no channel # Add it to a pending list, and return @@ -103,7 +99,7 @@ def unregister(self, socket_object): try: - channel = self._getselectable(socket_object) + channel = _getselectable(socket_object) self.chanmap[channel][1].cancel() del self.chanmap[channel] except java.lang.Exception, jlx: @@ -121,7 +117,7 @@ # No multiplication required: both cpython and java use millisecond timeouts self.selector.select(timeout) except ValueError, vx: - raise error("poll timeout must be a number of milliseconds or None", ERRNO_EINVAL) + raise error("poll timeout must be a 
number of milliseconds or None", errno.EINVAL) # The returned selectedKeys cannot be used from multiple threads! return self.selector.selectedKeys() @@ -159,41 +155,57 @@ except Exception, x: raise TypeError("Select timeout value must be a number or None") if value < 0: - raise error("Select timeout value cannot be negative", ERRNO_EINVAL) + raise error("Select timeout value cannot be negative", errno.EINVAL) if floatvalue < 0.000001: return 0 return int(floatvalue * 1000) # Convert to milliseconds -def select ( read_fd_list, write_fd_list, outofband_fd_list, timeout=None): +def native_select(read_fd_list, write_fd_list, outofband_fd_list, timeout=None): timeout = _calcselecttimeoutvalue(timeout) # First create a poll object to do the actual watching. pobj = poll() - already_registered = {} - # Check the read list try: - # AMAK: Need to remove all this list searching, change to a dictionary? + registered_for_read = {} + # Check the read list for fd in read_fd_list: - mask = POLLIN - if fd in write_fd_list: - mask |= POLLOUT - pobj.register(fd, mask) - already_registered[fd] = 1 + pobj.register(fd, POLLIN) + registered_for_read[fd] = 1 # And now the write list for fd in write_fd_list: - if not already_registered.has_key(fd): + if registered_for_read.has_key(fd): + # registering a second time overwrites the first + pobj.register(fd, POLLIN|POLLOUT) + else: pobj.register(fd, POLLOUT) results = pobj.poll(timeout) - except AttributeError, ax: - if str(ax) == "__getitem__": - raise TypeError(ax) - raise ax - # Now start preparing the results - read_ready_list, write_ready_list, oob_ready_list = [], [], [] - for fd, mask in results: - if mask & POLLIN: - read_ready_list.append(fd) - if mask & POLLOUT: - write_ready_list.append(fd) - pobj.close() - return read_ready_list, write_ready_list, oob_ready_list + # Now start preparing the results + read_ready_list, write_ready_list, oob_ready_list = [], [], [] + for fd, mask in results: + if mask & POLLIN: + 
read_ready_list.append(fd) + if mask & POLLOUT: + write_ready_list.append(fd) + return read_ready_list, write_ready_list, oob_ready_list + finally: + # Need to close the poll object no matter what happened + # If it is left open, it may still have references to sockets + # That were registered before any exceptions occurred + pobj.close() +select = native_select + +def cpython_compatible_select(read_fd_list, write_fd_list, outofband_fd_list, timeout=None): + # First turn all sockets to non-blocking + # keeping track of which ones have changed + modified_channels = [] + try: + for socket_list in [read_fd_list, write_fd_list, outofband_fd_list]: + for s in socket_list: + channel = _getselectable(s) + if channel.isBlocking(): + modified_channels.append(channel) + channel.configureBlocking(0) + return native_select(read_fd_list, write_fd_list, outofband_fd_list, timeout) + finally: + for channel in modified_channels: + channel.configureBlocking(1) Modified: trunk/jython/Lib/socket.py =================================================================== --- trunk/jython/Lib/socket.py 2007-07-19 07:19:30 UTC (rev 3331) +++ trunk/jython/Lib/socket.py 2007-07-19 18:27:23 UTC (rev 3332) @@ -432,6 +432,9 @@ self.timeout = 0.0 self._config() + def getblocking(self): + return self.mode == MODE_BLOCKING + def _config(self): assert self.mode in _permitted_modes if self.sock_impl: self.sock_impl.config(self.mode, self.timeout) Modified: trunk/jython/Lib/test/test_select.py =================================================================== --- trunk/jython/Lib/test/test_select.py 2007-07-19 07:19:30 UTC (rev 3331) +++ trunk/jython/Lib/test/test_select.py 2007-07-19 18:27:23 UTC (rev 3332) @@ -61,8 +61,10 @@ try: timeout = 0 # Can't wait forever rfd, wfd, xfd = select.select(args[0], args[1], args[2], timeout) - except TypeError: + except (select.error, TypeError): pass + except Exception, x: + self.fail("Selecting on '%s' raised wrong exception %s" % (str(bad_select_set), str(x))) 
else: self.fail("Selecting on '%s' should have raised TypeError" % str(bad_select_set)) Modified: trunk/jython/Lib/test/test_socket.py =================================================================== --- trunk/jython/Lib/test/test_socket.py 2007-07-19 07:19:30 UTC (rev 3331) +++ trunk/jython/Lib/test/test_socket.py 2007-07-19 18:27:23 UTC (rev 3332) @@ -648,6 +648,16 @@ def _testSetBlocking(self): pass + def testGetBlocking(self): + # Testing whether set blocking works + self.serv.setblocking(0) + self.failUnless(not self.serv.getblocking(), "Getblocking return true instead of false") + self.serv.setblocking(1) + self.failUnless(self.serv.getblocking(), "Getblocking return false instead of true") + + def _testGetBlocking(self): + pass + # # AMAK: 20070307 # Split testAccept into two separate tests This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <pj...@us...> - 2007-11-23 00:36:33
|
Revision: 3712 http://jython.svn.sourceforge.net/jython/?rev=3712&view=rev Author: pjenvey Date: 2007-11-22 16:36:30 -0800 (Thu, 22 Nov 2007) Log Message: ----------- update tempfile and its test to CPython 2.5.1's version with a couple modifications for Jython. allows test_pkg and test_shutil to pass fixes #1755344, #1783692 Modified Paths: -------------- trunk/jython/Lib/tempfile.py trunk/jython/Lib/test/regrtest.py Added Paths: ----------- trunk/jython/Lib/test/test_tempfile.py Modified: trunk/jython/Lib/tempfile.py =================================================================== --- trunk/jython/Lib/tempfile.py 2007-11-23 00:02:50 UTC (rev 3711) +++ trunk/jython/Lib/tempfile.py 2007-11-23 00:36:30 UTC (rev 3712) @@ -1,188 +1,394 @@ -# XXX added to fix jython specific problem. Should be removed when real -# problem is fixed. -import java.io.File -"""Temporary files and filenames.""" +"""Temporary files. -# XXX This tries to be not UNIX specific, but I don't know beans about -# how to choose a temp directory or filename on MS-DOS or other -# systems so it may have to be changed... +This module provides generic, low- and high-level interfaces for +creating temporary files and directories. The interfaces listed +as "safe" just below can be used without fear of race conditions. +Those listed as "unsafe" cannot, and are provided for backward +compatibility only. -import os +This module also provides some data items to the user: -__all__ = ["mktemp", "TemporaryFile", "tempdir", "gettempprefix"] + TMP_MAX - maximum number of names that will be tried before + giving up. + template - the default prefix for all temporary names. + You may change this to control the default prefix. + tempdir - If this is set to a string before the first use of + any routine from this module, it will be considered as + another candidate location to store temporary files. 
+""" -# Parameters that the caller may set to override the defaults +__all__ = [ + "NamedTemporaryFile", "TemporaryFile", # high level safe interfaces + "mkstemp", "mkdtemp", # low level safe interfaces + "mktemp", # deprecated unsafe interface + "TMP_MAX", "gettempprefix", # constants + "tempdir", "gettempdir" + ] + + +# Imports. + +import os as _os +import errno as _errno +from random import Random as _Random + +if _os.name == 'mac': + import Carbon.Folder as _Folder + import Carbon.Folders as _Folders + +try: + import fcntl as _fcntl +except ImportError: + def _set_cloexec(fd): + pass +else: + def _set_cloexec(fd): + try: + flags = _fcntl.fcntl(fd, _fcntl.F_GETFD, 0) + except IOError: + pass + else: + # flags read successfully, modify + flags |= _fcntl.FD_CLOEXEC + _fcntl.fcntl(fd, _fcntl.F_SETFD, flags) + + +try: + import thread as _thread +except ImportError: + import dummy_thread as _thread +_allocate_lock = _thread.allocate_lock + +_text_openflags = _os.O_RDWR | _os.O_CREAT | _os.O_EXCL +if hasattr(_os, 'O_NOINHERIT'): + _text_openflags |= _os.O_NOINHERIT +if hasattr(_os, 'O_NOFOLLOW'): + _text_openflags |= _os.O_NOFOLLOW + +_bin_openflags = _text_openflags +if hasattr(_os, 'O_BINARY'): + _bin_openflags |= _os.O_BINARY + +if hasattr(_os, 'TMP_MAX'): + TMP_MAX = _os.TMP_MAX +else: + TMP_MAX = 10000 + +template = "tmp" + tempdir = None -template = None -def gettempdir(): - """Function to calculate the directory to use.""" - global tempdir - if tempdir is not None: - return tempdir +# Internal routines. - # _gettempdir_inner deduces whether a candidate temp dir is usable by - # trying to create a file in it, and write to it. If that succeeds, - # great, it closes the file and unlinks it. 
There's a race, though: - # the *name* of the test file it tries is the same across all threads - # under most OSes (Linux is an exception), and letting multiple threads - # all try to open, write to, close, and unlink a single file can cause - # a variety of bogus errors (e.g., you cannot unlink a file under - # Windows if anyone has it open, and two threads cannot create the - # same file in O_EXCL mode under Unix). The simplest cure is to serialize - # calls to _gettempdir_inner. This isn't a real expense, because the - # first thread to succeed sets the global tempdir, and all subsequent - # calls to gettempdir() reuse that without trying _gettempdir_inner. - _tempdir_lock.acquire() +_once_lock = _allocate_lock() + +if hasattr(_os, "lstat"): + _stat = _os.lstat +elif hasattr(_os, "stat"): + _stat = _os.stat +else: + # Fallback. All we need is something that raises os.error if the + # file doesn't exist. + def _stat(fn): + try: + f = open(fn) + except IOError: + raise _os.error + f.close() + +def _exists(fn): try: - return _gettempdir_inner() - finally: - _tempdir_lock.release() + _stat(fn) + except _os.error: + return False + else: + return True -def _gettempdir_inner(): - """Function to calculate the directory to use.""" - global tempdir - if tempdir is not None: - return tempdir - try: - pwd = os.getcwd() - except (AttributeError, os.error): - pwd = os.curdir - attempdirs = ['/tmp', '/var/tmp', '/usr/tmp', pwd] - if os.name == 'nt': - attempdirs.insert(0, 'C:\\TEMP') - attempdirs.insert(0, '\\TEMP') - elif os.name == 'mac': - import macfs, MACFS +class _RandomNameSequence: + """An instance of _RandomNameSequence generates an endless + sequence of unpredictable strings which can safely be incorporated + into file names. Each string is six characters long. Multiple + threads can safely use the same instance at the same time. 
+ + _RandomNameSequence is an iterator.""" + + characters = ("abcdefghijklmnopqrstuvwxyz" + + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + "0123456789-_") + + def __init__(self): + self.mutex = _allocate_lock() + self.rng = _Random() + self.normcase = _os.path.normcase + + def __iter__(self): + return self + + def next(self): + m = self.mutex + c = self.characters + choose = self.rng.choice + + m.acquire() try: - refnum, dirid = macfs.FindFolder(MACFS.kOnSystemDisk, - MACFS.kTemporaryFolderType, 1) - dirname = macfs.FSSpec((refnum, dirid, '')).as_pathname() - attempdirs.insert(0, dirname) - except macfs.error: - pass - elif os.name == 'riscos': - scrapdir = os.getenv('Wimp$ScrapDir') - if scrapdir: - attempdirs.insert(0, scrapdir) + letters = [choose(c) for dummy in "123456"] + finally: + m.release() + + return self.normcase(''.join(letters)) + +def _candidate_tempdir_list(): + """Generate a list of candidate temporary directories which + _get_default_tempdir will try.""" + + dirlist = [] + + # First, try the environment. for envname in 'TMPDIR', 'TEMP', 'TMP': - if os.environ.has_key(envname): - attempdirs.insert(0, os.environ[envname]) - testfile = gettempprefix() + 'test' - for dir in attempdirs: + dirname = _os.getenv(envname) + if dirname: dirlist.append(dirname) + + # Failing that, try OS-specific locations. 
+ if _os.name == 'mac': try: - filename = os.path.join(dir, testfile) - if os.name == 'posix': - try: - fd = os.open(filename, - os.O_RDWR | os.O_CREAT | os.O_EXCL, 0700) - except OSError: - pass - else: - fp = os.fdopen(fd, 'w') - fp.write('blat') - fp.close() - os.unlink(filename) - del fp, fd - tempdir = dir - break - else: - fp = open(filename, 'w') + fsr = _Folder.FSFindFolder(_Folders.kOnSystemDisk, + _Folders.kTemporaryFolderType, 1) + dirname = fsr.as_pathname() + dirlist.append(dirname) + except _Folder.error: + pass + elif _os.name == 'riscos': + dirname = _os.getenv('Wimp$ScrapDir') + if dirname: dirlist.append(dirname) + elif _os.name == 'nt': + dirlist.extend([ r'c:\temp', r'c:\tmp', r'\temp', r'\tmp' ]) + else: + dirlist.extend([ '/tmp', '/var/tmp', '/usr/tmp' ]) + + # As a last resort, the current directory. + try: + dirlist.append(_os.getcwd()) + except (AttributeError, _os.error): + dirlist.append(_os.curdir) + + return dirlist + +def _get_default_tempdir(): + """Calculate the default directory to use for temporary files. + This routine should be called exactly once. + + We determine whether or not a candidate temp dir is usable by + trying to create and write to a file in that directory. If this + is successful, the test file is deleted. To prevent denial of + service, the name of the test file must be randomized.""" + + namer = _RandomNameSequence() + dirlist = _candidate_tempdir_list() + flags = _text_openflags + + for dir in dirlist: + if dir != _os.curdir: + dir = _os.path.normcase(_os.path.abspath(dir)) + # Try only a few names per directory. 
+ for seq in xrange(100): + name = namer.next() + filename = _os.path.join(dir, name) + try: + fd = _os.open(filename, flags, 0600) + fp = _os.fdopen(fd, 'w') fp.write('blat') fp.close() - os.unlink(filename) - tempdir = dir - break - except IOError: - pass + _os.unlink(filename) + del fp, fd + return dir + except (OSError, IOError), e: + if e[0] != _errno.EEXIST: + break # no point trying more names in this directory + pass + raise IOError, (_errno.ENOENT, + ("No usable temporary directory found in %s" % dirlist)) + +_name_sequence = None + +def _get_candidate_names(): + """Common setup sequence for all user-callable interfaces.""" + + global _name_sequence + if _name_sequence is None: + _once_lock.acquire() + try: + if _name_sequence is None: + _name_sequence = _RandomNameSequence() + finally: + _once_lock.release() + return _name_sequence + + +def _mkstemp_inner(dir, pre, suf, flags): + """Code common to mkstemp, TemporaryFile, and NamedTemporaryFile.""" + + names = _get_candidate_names() + + for seq in xrange(TMP_MAX): + name = names.next() + file = _os.path.join(dir, pre + name + suf) + try: + fd = _os.open(file, flags, 0600) + _set_cloexec(fd) + return (fd, _os.path.abspath(file)) + except OSError, e: + if e.errno == _errno.EEXIST: + continue # try again + raise + + raise IOError, (_errno.EEXIST, "No usable temporary file name found") + + +# User visible interfaces. 
+ +def gettempprefix(): + """Accessor for tempdir.template.""" + return template + +tempdir = None + +def gettempdir(): + """Accessor for tempdir.tempdir.""" + global tempdir if tempdir is None: - msg = "Can't find a usable temporary directory amongst " + `attempdirs` - raise IOError, msg + _once_lock.acquire() + try: + if tempdir is None: + tempdir = _get_default_tempdir() + finally: + _once_lock.release() return tempdir +def mkstemp(suffix="", prefix=template, dir=None, text=False): + """mkstemp([suffix, [prefix, [dir, [text]]]]) + User-callable function to create and return a unique temporary + file. The return value is a pair (fd, name) where fd is the + file descriptor returned by os.open, and name is the filename. -# template caches the result of gettempprefix, for speed, when possible. -# XXX unclear why this isn't "_template"; left it "template" for backward -# compatibility. -if os.name == "posix": - # We don't try to cache the template on posix: the pid may change on us - # between calls due to a fork, and on Linux the pid changes even for - # another thread in the same process. Since any attempt to keep the - # cache in synch would have to call os.getpid() anyway in order to make - # sure the pid hasn't changed between calls, a cache wouldn't save any - # time. In addition, a cache is difficult to keep correct with the pid - # changing willy-nilly, and earlier attempts proved buggy (races). - template = None + If 'suffix' is specified, the file name will end with that suffix, + otherwise there will be no suffix. -# Else the pid never changes, so gettempprefix always returns the same -# string. -elif os.name == "nt": - template = '~' + `os.getpid()` + '-' -elif os.name in ('mac', 'riscos'): - template = 'Python-Tmp-' -else: - template = 'tmp' # XXX might choose a better one + If 'prefix' is specified, the file name will begin with that prefix, + otherwise a default prefix is used. 
-def gettempprefix(): - """Function to calculate a prefix of the filename to use. + If 'dir' is specified, the file will be created in that directory, + otherwise a default directory is used. - This incorporates the current process id on systems that support such a - notion, so that concurrent processes don't generate the same prefix. + If 'text' is specified and true, the file is opened in text + mode. Else (the default) the file is opened in binary mode. On + some operating systems, this makes no difference. + + The file is readable and writable only by the creating user ID. + If the operating system uses permission bits to indicate whether a + file is executable, the file is executable by no one. The file + descriptor is not inherited by children of this process. + + Caller is responsible for deleting the file when done with it. """ - global template - if template is None: - return '@' + `os.getpid()` + '.' + if dir is None: + dir = gettempdir() + + if text: + flags = _text_openflags else: - return template + flags = _bin_openflags + return _mkstemp_inner(dir, prefix, suffix, flags) -def mktemp(suffix=""): - """User-callable function to return a unique temporary file name.""" - dir = gettempdir() - pre = gettempprefix() - while 1: - i = _counter.get_next() - file = os.path.join(dir, pre + str(i) + suffix) - if not os.path.exists(file): + +def mkdtemp(suffix="", prefix=template, dir=None): + """mkdtemp([suffix, [prefix, [dir]]]) + User-callable function to create and return a unique temporary + directory. The return value is the pathname of the directory. + + Arguments are as for mkstemp, except that the 'text' argument is + not accepted. + + The directory is readable, writable, and searchable only by the + creating user. + + Caller is responsible for deleting the directory when done with it. 
+ """ + + if dir is None: + dir = gettempdir() + + names = _get_candidate_names() + + for seq in xrange(TMP_MAX): + name = names.next() + file = _os.path.join(dir, prefix + name + suffix) + try: + _os.mkdir(file, 0700) return file + except OSError, e: + if e.errno == _errno.EEXIST: + continue # try again + raise + raise IOError, (_errno.EEXIST, "No usable temporary directory name found") -class TemporaryFileWrapper: - """Temporary file wrapper +def mktemp(suffix="", prefix=template, dir=None): + """mktemp([suffix, [prefix, [dir]]]) + User-callable function to return a unique temporary file name. The + file is not created. - This class provides a wrapper around files opened for temporary use. - In particular, it seeks to automatically remove the file when it is - no longer needed. + Arguments are as for mkstemp, except that the 'text' argument is + not accepted. + + This function is unsafe and should not be used. The file name + refers to a file that did not exist at some point, but by the time + you get around to creating it, someone else may have beaten you to + the punch. """ - # Cache the unlinker so we don't get spurious errors at shutdown - # when the module-level "os" is None'd out. Note that this must - # be referenced as self.unlink, because the name TemporaryFileWrapper - # may also get None'd out before __del__ is called. +## from warnings import warn as _warn +## _warn("mktemp is a potential security risk to your program", +## RuntimeWarning, stacklevel=2) - # XXX: unlink = os.unlink does not work in jython, really that should be fixed and - # the original python class could be used. 
- if os.name == "java": - def unlink(self, path): - if not java.io.File(path).delete(): - raise OSError(0, "couldn't delete file", path) - else: - unlink = os.unlink + if dir is None: + dir = gettempdir() - def __init__(self, file, path): + names = _get_candidate_names() + for seq in xrange(TMP_MAX): + name = names.next() + file = _os.path.join(dir, prefix + name + suffix) + if not _exists(file): + return file + + raise IOError, (_errno.EEXIST, "No usable temporary filename found") + +class _TemporaryFileWrapper: + """Temporary file wrapper + + This class provides a wrapper around files opened for + temporary use. In particular, it seeks to automatically + remove the file when it is no longer needed. + """ + + def __init__(self, file, name): self.file = file - self.path = path - self.close_called = 0 + self.name = name + self.close_called = False - def close(self): - if not self.close_called: - self.close_called = 1 - self.file.close() - self.unlink(self.path) + # XXX: CPython assigns unlink as a class var but this would + # rebind Jython's os.unlink (to be a classmethod) because it's + # not a built-in function (unfortunately built-in functions act + # differently when binding: + # http://mail.python.org/pipermail/python-dev/2003-April/034749.html) - def __del__(self): - self.close() + # Cache the unlinker so we don't get spurious errors at + # shutdown when the module-level "os" is None'd out. Note + # that this must be referenced as self.unlink, because the + # name TemporaryFileWrapper may also get None'd out before + # __del__ is called. + self.unlink = _os.unlink def __getattr__(self, name): file = self.__dict__['file'] @@ -191,65 +397,82 @@ setattr(self, name, a) return a + # NT provides delete-on-close as a primitive, so we don't need + # the wrapper to do anything special. We still use it so that + # file.name is useful (i.e. not "(fdopen)") with NamedTemporaryFile. 
+ if _os.name != 'nt': -def TemporaryFile(mode='w+b', bufsize=-1, suffix=""): - """Create and return a temporary file (opened read-write by default).""" - name = mktemp(suffix) - if os.name == 'posix': - # Unix -- be very careful - fd = os.open(name, os.O_RDWR|os.O_CREAT|os.O_EXCL, 0700) - try: - os.unlink(name) - return os.fdopen(fd, mode, bufsize) - except: - os.close(fd) - raise - else: - # Non-unix -- can't unlink file that's still open, use wrapper - file = open(name, mode, bufsize) - return TemporaryFileWrapper(file, name) + def close(self): + if not self.close_called: + self.close_called = True + self.file.close() + self.unlink(self.name) -# In order to generate unique names, mktemp() uses _counter.get_next(). -# This returns a unique integer on each call, in a threadsafe way (i.e., -# multiple threads will never see the same integer). The integer will -# usually be a Python int, but if _counter.get_next() is called often -# enough, it will become a Python long. -# Note that the only names that survive this next block of code -# are "_counter" and "_tempdir_lock". + def __del__(self): + self.close() -class _ThreadSafeCounter: - def __init__(self, mutex, initialvalue=0): - self.mutex = mutex - self.i = initialvalue +def NamedTemporaryFile(mode='w+b', bufsize=-1, suffix="", + prefix=template, dir=None): + """Create and return a temporary file. + Arguments: + 'prefix', 'suffix', 'dir' -- as for mkstemp. + 'mode' -- the mode argument to os.fdopen (default "w+b"). + 'bufsize' -- the buffer size argument to os.fdopen (default -1). + The file is created as mkstemp() would do it. - def get_next(self): - self.mutex.acquire() - result = self.i - try: - newi = result + 1 - except OverflowError: - newi = long(result) + 1 - self.i = newi - self.mutex.release() - return result + Returns an object with a file-like interface; the name of the file + is accessible as file.name. The file will be automatically deleted + when it is closed. 
+ """ -try: - import thread + if dir is None: + dir = gettempdir() -except ImportError: - class _DummyMutex: - def acquire(self): - pass + if 'b' in mode: + flags = _bin_openflags + else: + flags = _text_openflags - release = acquire + # Setting O_TEMPORARY in the flags causes the OS to delete + # the file when it is closed. This is only supported by Windows. + if _os.name == 'nt': + flags |= _os.O_TEMPORARY - _counter = _ThreadSafeCounter(_DummyMutex()) - _tempdir_lock = _DummyMutex() - del _DummyMutex + (fd, name) = _mkstemp_inner(dir, prefix, suffix, flags) + file = _os.fdopen(fd, mode, bufsize) + return _TemporaryFileWrapper(file, name) +if _os.name != 'posix' or _os.sys.platform == 'cygwin': + # On non-POSIX and Cygwin systems, assume that we cannot unlink a file + # while it is open. + TemporaryFile = NamedTemporaryFile + else: - _counter = _ThreadSafeCounter(thread.allocate_lock()) - _tempdir_lock = thread.allocate_lock() - del thread + def TemporaryFile(mode='w+b', bufsize=-1, suffix="", + prefix=template, dir=None): + """Create and return a temporary file. + Arguments: + 'prefix', 'suffix', 'dir' -- as for mkstemp. + 'mode' -- the mode argument to os.fdopen (default "w+b"). + 'bufsize' -- the buffer size argument to os.fdopen (default -1). + The file is created as mkstemp() would do it. -del _ThreadSafeCounter + Returns an object with a file-like interface. The file has no + name, and will cease to exist when it is closed. 
+ """ + + if dir is None: + dir = gettempdir() + + if 'b' in mode: + flags = _bin_openflags + else: + flags = _text_openflags + + (fd, name) = _mkstemp_inner(dir, prefix, suffix, flags) + try: + _os.unlink(name) + return _os.fdopen(fd, mode, bufsize) + except: + _os.close(fd) + raise Modified: trunk/jython/Lib/test/regrtest.py =================================================================== --- trunk/jython/Lib/test/regrtest.py 2007-11-23 00:02:50 UTC (rev 3711) +++ trunk/jython/Lib/test/regrtest.py 2007-11-23 00:36:30 UTC (rev 3712) @@ -1046,18 +1046,15 @@ test_pep277 test_pickle test_pickletools - test_pkg test_pkgimport test_posixpath test_profilehooks test_pyclbr test_quopri test_random - test_shutil test_slice test_softspace test_syntax - test_tempfile test_threaded_import test_trace test_ucn Added: trunk/jython/Lib/test/test_tempfile.py =================================================================== --- trunk/jython/Lib/test/test_tempfile.py (rev 0) +++ trunk/jython/Lib/test/test_tempfile.py 2007-11-23 00:36:30 UTC (rev 3712) @@ -0,0 +1,699 @@ +# From Python 2.5.1 +# tempfile.py unit tests. + +import tempfile +import os +import sys +import re +import errno +import warnings + +import unittest +from test import test_support + +warnings.filterwarnings("ignore", + category=RuntimeWarning, + message="mktemp", module=__name__) + +if hasattr(os, 'stat'): + import stat + has_stat = 1 +else: + has_stat = 0 + +has_textmode = (tempfile._text_openflags != tempfile._bin_openflags) +has_spawnl = hasattr(os, 'spawnl') + +# TEST_FILES may need to be tweaked for systems depending on the maximum +# number of files that can be opened at one time (see ulimit -n) +if sys.platform == 'mac': + TEST_FILES = 32 +elif sys.platform in ('openbsd3', 'openbsd4'): + TEST_FILES = 48 +else: + TEST_FILES = 100 + +# This is organized as one test for each chunk of code in tempfile.py, +# in order of their appearance in the file. Testing which requires +# threads is not done here. 
+ +# Common functionality. +class TC(unittest.TestCase): + + str_check = re.compile(r"[a-zA-Z0-9_-]{6}$") + + def failOnException(self, what, ei=None): + if ei is None: + ei = sys.exc_info() + self.fail("%s raised %s: %s" % (what, ei[0], ei[1])) + + def nameCheck(self, name, dir, pre, suf): + (ndir, nbase) = os.path.split(name) + npre = nbase[:len(pre)] + nsuf = nbase[len(nbase)-len(suf):] + + # check for equality of the absolute paths! + self.assertEqual(os.path.abspath(ndir), os.path.abspath(dir), + "file '%s' not in directory '%s'" % (name, dir)) + self.assertEqual(npre, pre, + "file '%s' does not begin with '%s'" % (nbase, pre)) + self.assertEqual(nsuf, suf, + "file '%s' does not end with '%s'" % (nbase, suf)) + + nbase = nbase[len(pre):len(nbase)-len(suf)] + self.assert_(self.str_check.match(nbase), + "random string '%s' does not match /^[a-zA-Z0-9_-]{6}$/" + % nbase) + +test_classes = [] + +class test_exports(TC): + def test_exports(self): + # There are no surprising symbols in the tempfile module + dict = tempfile.__dict__ + + expected = { + "NamedTemporaryFile" : 1, + "TemporaryFile" : 1, + "mkstemp" : 1, + "mkdtemp" : 1, + "mktemp" : 1, + "TMP_MAX" : 1, + "gettempprefix" : 1, + "gettempdir" : 1, + "tempdir" : 1, + "template" : 1 + } + + unexp = [] + for key in dict: + if key[0] != '_' and key not in expected: + unexp.append(key) + self.failUnless(len(unexp) == 0, + "unexpected keys: %s" % unexp) + +test_classes.append(test_exports) + + +class test__RandomNameSequence(TC): + """Test the internal iterator object _RandomNameSequence.""" + + def setUp(self): + self.r = tempfile._RandomNameSequence() + + def test_get_six_char_str(self): + # _RandomNameSequence returns a six-character string + s = self.r.next() + self.nameCheck(s, '', '', '') + + def test_many(self): + # _RandomNameSequence returns no duplicate strings (stochastic) + + dict = {} + r = self.r + for i in xrange(TEST_FILES): + s = r.next() + self.nameCheck(s, '', '', '') + self.failIf(s in dict) + 
dict[s] = 1 + + def test_supports_iter(self): + # _RandomNameSequence supports the iterator protocol + + i = 0 + r = self.r + try: + for s in r: + i += 1 + if i == 20: + break + except: + failOnException("iteration") + +test_classes.append(test__RandomNameSequence) + + +class test__candidate_tempdir_list(TC): + """Test the internal function _candidate_tempdir_list.""" + + def test_nonempty_list(self): + # _candidate_tempdir_list returns a nonempty list of strings + + cand = tempfile._candidate_tempdir_list() + + self.failIf(len(cand) == 0) + for c in cand: + self.assert_(isinstance(c, basestring), + "%s is not a string" % c) + + def test_wanted_dirs(self): + # _candidate_tempdir_list contains the expected directories + + # Make sure the interesting environment variables are all set. + added = [] + try: + for envname in 'TMPDIR', 'TEMP', 'TMP': + dirname = os.getenv(envname) + if not dirname: + os.environ[envname] = os.path.abspath(envname) + added.append(envname) + + cand = tempfile._candidate_tempdir_list() + + for envname in 'TMPDIR', 'TEMP', 'TMP': + dirname = os.getenv(envname) + if not dirname: raise ValueError + self.assert_(dirname in cand) + + try: + dirname = os.getcwd() + except (AttributeError, os.error): + dirname = os.curdir + + self.assert_(dirname in cand) + + # Not practical to try to verify the presence of OS-specific + # paths in this list. + finally: + for p in added: + del os.environ[p] + +test_classes.append(test__candidate_tempdir_list) + + +# We test _get_default_tempdir by testing gettempdir. 
+ + +class test__get_candidate_names(TC): + """Test the internal function _get_candidate_names.""" + + def test_retval(self): + # _get_candidate_names returns a _RandomNameSequence object + obj = tempfile._get_candidate_names() + self.assert_(isinstance(obj, tempfile._RandomNameSequence)) + + def test_same_thing(self): + # _get_candidate_names always returns the same object + a = tempfile._get_candidate_names() + b = tempfile._get_candidate_names() + + self.assert_(a is b) + +test_classes.append(test__get_candidate_names) + + +class test__mkstemp_inner(TC): + """Test the internal function _mkstemp_inner.""" + + class mkstemped: + _bflags = tempfile._bin_openflags + _tflags = tempfile._text_openflags + + def __init__(self, dir, pre, suf, bin): + if bin: flags = self._bflags + else: flags = self._tflags + + # XXX: CPython assigns _close/_unlink as class vars but this + # would rebind Jython's close/unlink (to be classmethods) + # because they're not built-in functions (unfortunately + # built-in functions act differently when binding: + # http://mail.python.org/pipermail/python-dev/2003-April/034749.html) + self._close = os.close + self._unlink = os.unlink + (self.fd, self.name) = tempfile._mkstemp_inner(dir, pre, suf, flags) + + def write(self, str): + os.write(self.fd, str) + # XXX: self.test_choose_directory expects the file to have been deleted + # (via __del__) by the time it's called, which is CPython specific + # garbage collection behavior. 
We need to delete it now in Jython + self._close(self.fd) + self._unlink(self.name) + + def __del__(self): + self._close(self.fd) + if os.path.exists(self.name): + self._unlink(self.name) + + def do_create(self, dir=None, pre="", suf="", bin=1): + if dir is None: + dir = tempfile.gettempdir() + try: + file = self.mkstemped(dir, pre, suf, bin) + except: + self.failOnException("_mkstemp_inner") + + self.nameCheck(file.name, dir, pre, suf) + return file + + def test_basic(self): + # _mkstemp_inner can create files + self.do_create().write("blat") + self.do_create(pre="a").write("blat") + self.do_create(suf="b").write("blat") + self.do_create(pre="a", suf="b").write("blat") + self.do_create(pre="aa", suf=".txt").write("blat") + + def test_basic_many(self): + # _mkstemp_inner can create many files (stochastic) + extant = range(TEST_FILES) + for i in extant: + extant[i] = self.do_create(pre="aa") + # XXX: Ensure mkstemped files are deleted (can't rely on Java's + # GC) + for i in extant: + i.__del__() + + def test_choose_directory(self): + # _mkstemp_inner can create files in a user-selected directory + dir = tempfile.mkdtemp() + try: + self.do_create(dir=dir).write("blat") + finally: + os.rmdir(dir) + + # XXX: Jython can't set the write mode yet + def _test_file_mode(self): + # _mkstemp_inner creates files with the proper mode + if not has_stat: + return # ugh, can't use TestSkipped. + + file = self.do_create() + mode = stat.S_IMODE(os.stat(file.name).st_mode) + expected = 0600 + if sys.platform in ('win32', 'os2emx', 'mac'): + # There's no distinction among 'user', 'group' and 'world'; + # replicate the 'user' bits. + user = expected >> 6 + expected = user * (1 + 8 + 64) + self.assertEqual(mode, expected) + + def test_noinherit(self): + # _mkstemp_inner file handles are not inherited by child processes + if not has_spawnl: + return # ugh, can't use TestSkipped. 
+ + if test_support.verbose: + v="v" + else: + v="q" + + file = self.do_create() + fd = "%d" % file.fd + + try: + me = __file__ + except NameError: + me = sys.argv[0] + + # We have to exec something, so that FD_CLOEXEC will take + # effect. The core of this test is therefore in + # tf_inherit_check.py, which see. + tester = os.path.join(os.path.dirname(os.path.abspath(me)), + "tf_inherit_check.py") + + # On Windows a spawn* /path/ with embedded spaces shouldn't be quoted, + # but an arg with embedded spaces should be decorated with double + # quotes on each end + if sys.platform in ('win32'): + decorated = '"%s"' % sys.executable + tester = '"%s"' % tester + else: + decorated = sys.executable + + retval = os.spawnl(os.P_WAIT, sys.executable, decorated, tester, v, fd) + self.failIf(retval < 0, + "child process caught fatal signal %d" % -retval) + self.failIf(retval > 0, "child process reports failure %d"%retval) + + def test_textmode(self): + # _mkstemp_inner can create files in text mode + if not has_textmode: + return # ugh, can't use TestSkipped. + + self.do_create(bin=0).write("blat\n") + # XXX should test that the file really is a text file + +test_classes.append(test__mkstemp_inner) + + +class test_gettempprefix(TC): + """Test gettempprefix().""" + + def test_sane_template(self): + # gettempprefix returns a nonempty prefix string + p = tempfile.gettempprefix() + + self.assert_(isinstance(p, basestring)) + self.assert_(len(p) > 0) + + def test_usable_template(self): + # gettempprefix returns a usable prefix string + + # Create a temp directory, avoiding use of the prefix. + # Then attempt to create a file whose name is + # prefix + 'xxxxxx.xxx' in that directory. 
+ p = tempfile.gettempprefix() + "xxxxxx.xxx" + d = tempfile.mkdtemp(prefix="") + try: + p = os.path.join(d, p) + try: + fd = os.open(p, os.O_RDWR | os.O_CREAT) + except: + self.failOnException("os.open") + os.close(fd) + os.unlink(p) + finally: + os.rmdir(d) + +test_classes.append(test_gettempprefix) + + +class test_gettempdir(TC): + """Test gettempdir().""" + + def test_directory_exists(self): + # gettempdir returns a directory which exists + + dir = tempfile.gettempdir() + self.assert_(os.path.isabs(dir) or dir == os.curdir, + "%s is not an absolute path" % dir) + self.assert_(os.path.isdir(dir), + "%s is not a directory" % dir) + + def test_directory_writable(self): + # gettempdir returns a directory writable by the user + + # sneaky: just instantiate a NamedTemporaryFile, which + # defaults to writing into the directory returned by + # gettempdir. + try: + file = tempfile.NamedTemporaryFile() + file.write("blat") + file.close() + except: + self.failOnException("create file in %s" % tempfile.gettempdir()) + + def test_same_thing(self): + # gettempdir always returns the same object + a = tempfile.gettempdir() + b = tempfile.gettempdir() + + self.assert_(a is b) + +test_classes.append(test_gettempdir) + + +class test_mkstemp(TC): + """Test mkstemp().""" + + def do_create(self, dir=None, pre="", suf=""): + if dir is None: + dir = tempfile.gettempdir() + try: + (fd, name) = tempfile.mkstemp(dir=dir, prefix=pre, suffix=suf) + (ndir, nbase) = os.path.split(name) + adir = os.path.abspath(dir) + self.assertEqual(adir, ndir, + "Directory '%s' incorrectly returned as '%s'" % (adir, ndir)) + except: + self.failOnException("mkstemp") + + try: + self.nameCheck(name, dir, pre, suf) + finally: + os.close(fd) + os.unlink(name) + + def test_basic(self): + # mkstemp can create files + self.do_create() + self.do_create(pre="a") + self.do_create(suf="b") + self.do_create(pre="a", suf="b") + self.do_create(pre="aa", suf=".txt") + self.do_create(dir=".") + + def 
test_choose_directory(self): + # mkstemp can create directories in a user-selected directory + dir = tempfile.mkdtemp() + try: + self.do_create(dir=dir) + finally: + os.rmdir(dir) + +test_classes.append(test_mkstemp) + + +class test_mkdtemp(TC): + """Test mkdtemp().""" + + def do_create(self, dir=None, pre="", suf=""): + if dir is None: + dir = tempfile.gettempdir() + try: + name = tempfile.mkdtemp(dir=dir, prefix=pre, suffix=suf) + except: + self.failOnException("mkdtemp") + + try: + self.nameCheck(name, dir, pre, suf) + return name + except: + os.rmdir(name) + raise + + def test_basic(self): + # mkdtemp can create directories + os.rmdir(self.do_create()) + os.rmdir(self.do_create(pre="a")) + os.rmdir(self.do_create(suf="b")) + os.rmdir(self.do_create(pre="a", suf="b")) + os.rmdir(self.do_create(pre="aa", suf=".txt")) + + def test_basic_many(self): + # mkdtemp can create many directories (stochastic) + extant = range(TEST_FILES) + try: + for i in extant: + extant[i] = self.do_create(pre="aa") + finally: + for i in extant: + if(isinstance(i, basestring)): + os.rmdir(i) + + def test_choose_directory(self): + # mkdtemp can create directories in a user-selected directory + dir = tempfile.mkdtemp() + try: + os.rmdir(self.do_create(dir=dir)) + finally: + os.rmdir(dir) + + def test_mode(self): + # mkdtemp creates directories with the proper mode + if not has_stat: + return # ugh, can't use TestSkipped. + if os.name == 'java': + # Java doesn't support stating files for permissions + return + + dir = self.do_create() + try: + mode = stat.S_IMODE(os.stat(dir).st_mode) + mode &= 0777 # Mask off sticky bits inherited from /tmp + expected = 0700 + if sys.platform in ('win32', 'os2emx', 'mac'): + # There's no distinction among 'user', 'group' and 'world'; + # replicate the 'user' bits. 
+ user = expected >> 6 + expected = user * (1 + 8 + 64) + self.assertEqual(mode, expected) + finally: + os.rmdir(dir) + +test_classes.append(test_mkdtemp) + + +class test_mktemp(TC): + """Test mktemp().""" + + # For safety, all use of mktemp must occur in a private directory. + # We must also suppress the RuntimeWarning it generates. + def setUp(self): + self.dir = tempfile.mkdtemp() + + def tearDown(self): + if self.dir: + os.rmdir(self.dir) + self.dir = None + + class mktemped: + _bflags = tempfile._bin_openflags + + def __init__(self, dir, pre, suf): + # XXX: Assign _unlink here, instead of as a class var. See + # mkstemped.__init__ for an explanation + self._unlink = os.unlink + + self.name = tempfile.mktemp(dir=dir, prefix=pre, suffix=suf) + # Create the file. This will raise an exception if it's + # mysteriously appeared in the meanwhile. + os.close(os.open(self.name, self._bflags, 0600)) + # XXX: test_mktemp.tearDown expects the file to have been deleted + # (via __del__) by the time it's called, which is CPython specific + # garbage collection behavior. 
We need to delete it now in Jython + self._unlink(self.name) + + #def __del__(self): + # self._unlink(self.name) + + def do_create(self, pre="", suf=""): + try: + file = self.mktemped(self.dir, pre, suf) + except: + self.failOnException("mktemp") + + self.nameCheck(file.name, self.dir, pre, suf) + return file + + def test_basic(self): + # mktemp can choose usable file names + self.do_create() + self.do_create(pre="a") + self.do_create(suf="b") + self.do_create(pre="a", suf="b") + self.do_create(pre="aa", suf=".txt") + + def test_many(self): + # mktemp can choose many usable file names (stochastic) + extant = range(TEST_FILES) + for i in extant: + extant[i] = self.do_create(pre="aa") + +## def test_warning(self): +## # mktemp issues a warning when used +## warnings.filterwarnings("error", +## category=RuntimeWarning, +## message="mktemp") +## self.assertRaises(RuntimeWarning, +## tempfile.mktemp, dir=self.dir) + +test_classes.append(test_mktemp) + + +# We test _TemporaryFileWrapper by testing NamedTemporaryFile. 
+ + +class test_NamedTemporaryFile(TC): + """Test NamedTemporaryFile().""" + + def do_create(self, dir=None, pre="", suf=""): + if dir is None: + dir = tempfile.gettempdir() + try: + file = tempfile.NamedTemporaryFile(dir=dir, prefix=pre, suffix=suf) + except: + self.failOnException("NamedTemporaryFile") + + self.nameCheck(file.name, dir, pre, suf) + return file + + + def test_basic(self): + # NamedTemporaryFile can create files + self.do_create() + self.do_create(pre="a") + self.do_create(suf="b") + self.do_create(pre="a", suf="b") + self.do_create(pre="aa", suf=".txt") + + def test_creates_named(self): + # NamedTemporaryFile creates files with names + f = tempfile.NamedTemporaryFile() + self.failUnless(os.path.exists(f.name), + "NamedTemporaryFile %s does not exist" % f.name) + + def test_del_on_close(self): + # A NamedTemporaryFile is deleted when closed + dir = tempfile.mkdtemp() + try: + f = tempfile.NamedTemporaryFile(dir=dir) + f.write('blat') + f.close() + self.failIf(os.path.exists(f.name), + "NamedTemporaryFile %s exists after close" % f.name) + finally: + os.rmdir(dir) + + def test_multiple_close(self): + # A NamedTemporaryFile can be closed many times without error + + f = tempfile.NamedTemporaryFile() + f.write('abc\n') + f.close() + try: + f.close() + f.close() + except: + self.failOnException("close") + + # How to test the mode and bufsize parameters? + +test_classes.append(test_NamedTemporaryFile) + + +class test_TemporaryFile(TC): + """Test TemporaryFile().""" + + def test_basic(self): + # TemporaryFile can create files + # No point in testing the name params - the file has no name. 
+ try: + tempfile.TemporaryFile() + except: + self.failOnException("TemporaryFile") + + def test_has_no_name(self): + # TemporaryFile creates files with no names (on this system) + dir = tempfile.mkdtemp() + f = tempfile.TemporaryFile(dir=dir) + f.write('blat') + + # Sneaky: because this file has no name, it should not prevent + # us from removing the directory it was created in. + try: + os.rmdir(dir) + except: + ei = sys.exc_info() + # cleanup + f.close() + os.rmdir(dir) + self.failOnException("rmdir", ei) + + def test_multiple_close(self): + # A TemporaryFile can be closed many times without error + f = tempfile.TemporaryFile() + f.write('abc\n') + f.close() + try: + f.close() + f.close() + except: + self.failOnException("close") + + # How to test the mode and bufsize parameters? + + +if tempfile.NamedTemporaryFile is not tempfile.TemporaryFile: + test_classes.append(test_TemporaryFile) + +def test_main(): + test_support.run_unittest(*test_classes) + +if __name__ == "__main__": + test_main() + # XXX: Nudge Java's GC in an attempt to trigger any temp file's + # __del__ (cause them to be deleted) that hasn't been called + from java.lang import System + System.gc() This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <pj...@us...> - 2007-12-23 23:57:34
|
Revision: 3849 http://jython.svn.sourceforge.net/jython/?rev=3849&view=rev Author: pjenvey Date: 2007-12-23 15:57:32 -0800 (Sun, 23 Dec 2007) Log Message: ----------- add tarfile from CPython 2.5.1. we don't use CPython 2.3's because it relies on some obscure null byte handling in string atoi/atol that Jython doesn't do and isn't supported in later CPythons anyway Added Paths: ----------- trunk/jython/Lib/tarfile.py trunk/jython/Lib/test/cfgparser.1 trunk/jython/Lib/test/test_tarfile.py trunk/jython/Lib/test/testtar.tar Added: trunk/jython/Lib/tarfile.py =================================================================== --- trunk/jython/Lib/tarfile.py (rev 0) +++ trunk/jython/Lib/tarfile.py 2007-12-23 23:57:32 UTC (rev 3849) @@ -0,0 +1,2176 @@ +#!/usr/bin/env python +# -*- coding: iso-8859-1 -*- +#------------------------------------------------------------------- +# tarfile.py +#------------------------------------------------------------------- +# Copyright (C) 2002 Lars Gust\xE4bel <la...@gu...> +# All rights reserved. +# +# Permission is hereby granted, free of charge, to any person +# obtaining a copy of this software and associated documentation +# files (the "Software"), to deal in the Software without +# restriction, including without limitation the rights to use, +# copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following +# conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +# OTHER DEALINGS IN THE SOFTWARE. +# +"""Read from and write to tar format archives. +""" + +__version__ = "$Revision: 53162 $" +# $Source$ + +version = "0.8.0" +__author__ = "Lars Gust\xE4bel (la...@gu...)" +__date__ = "$Date: 2006-12-27 21:36:58 +1100 (Wed, 27 Dec 2006) $" +__cvsid__ = "$Id: tarfile.py 53162 2006-12-27 10:36:58Z lars.gustaebel $" +__credits__ = "Gustavo Niemeyer, Niels Gust\xE4bel, Richard Townsend." + +#--------- +# Imports +#--------- +import sys +import os +import shutil +import stat +import errno +import time +import struct +import copy + +if sys.platform == 'mac': + # This module needs work for MacOS9, especially in the area of pathname + # handling. In many places it is assumed a simple substitution of / by the + # local os.path.sep is good enough to convert pathnames, but this does not + # work with the mac rooted:path:name versus :nonrooted:path:name syntax + raise ImportError, "tarfile does not work for platform==mac" + +try: + import grp, pwd +except ImportError: + grp = pwd = None + +# from tarfile import * +__all__ = ["TarFile", "TarInfo", "is_tarfile", "TarError"] + +#--------------------------------------------------------- +# tar constants +#--------------------------------------------------------- +NUL = "\0" # the null character +BLOCKSIZE = 512 # length of processing blocks +RECORDSIZE = BLOCKSIZE * 20 # length of records +MAGIC = "ustar" # magic tar string +VERSION = "00" # version number + +LENGTH_NAME = 100 # maximum length of a filename +LENGTH_LINK = 100 # maximum length of a linkname +LENGTH_PREFIX = 155 # maximum length of the prefix field +MAXSIZE_MEMBER = 077777777777L # maximum size of a file (11 octal digits) + +REGTYPE = "0" # regular file +AREGTYPE = "\0" # regular file +LNKTYPE = 
"1" # link (inside tarfile) +SYMTYPE = "2" # symbolic link +CHRTYPE = "3" # character special device +BLKTYPE = "4" # block special device +DIRTYPE = "5" # directory +FIFOTYPE = "6" # fifo special device +CONTTYPE = "7" # contiguous file + +GNUTYPE_LONGNAME = "L" # GNU tar extension for longnames +GNUTYPE_LONGLINK = "K" # GNU tar extension for longlink +GNUTYPE_SPARSE = "S" # GNU tar extension for sparse file + +#--------------------------------------------------------- +# tarfile constants +#--------------------------------------------------------- +SUPPORTED_TYPES = (REGTYPE, AREGTYPE, LNKTYPE, # file types that tarfile + SYMTYPE, DIRTYPE, FIFOTYPE, # can cope with. + CONTTYPE, CHRTYPE, BLKTYPE, + GNUTYPE_LONGNAME, GNUTYPE_LONGLINK, + GNUTYPE_SPARSE) + +REGULAR_TYPES = (REGTYPE, AREGTYPE, # file types that somehow + CONTTYPE, GNUTYPE_SPARSE) # represent regular files + +#--------------------------------------------------------- +# Bits used in the mode field, values in octal. +#--------------------------------------------------------- +S_IFLNK = 0120000 # symbolic link +S_IFREG = 0100000 # regular file +S_IFBLK = 0060000 # block device +S_IFDIR = 0040000 # directory +S_IFCHR = 0020000 # character device +S_IFIFO = 0010000 # fifo + +TSUID = 04000 # set UID on execution +TSGID = 02000 # set GID on execution +TSVTX = 01000 # reserved + +TUREAD = 0400 # read by owner +TUWRITE = 0200 # write by owner +TUEXEC = 0100 # execute/search by owner +TGREAD = 0040 # read by group +TGWRITE = 0020 # write by group +TGEXEC = 0010 # execute/search by group +TOREAD = 0004 # read by other +TOWRITE = 0002 # write by other +TOEXEC = 0001 # execute/search by other + +#--------------------------------------------------------- +# Some useful functions +#--------------------------------------------------------- + +def stn(s, length): + """Convert a python string to a null-terminated string buffer. 
+ """ + return s[:length] + (length - len(s)) * NUL + +def nti(s): + """Convert a number field to a python number. + """ + # There are two possible encodings for a number field, see + # itn() below. + if s[0] != chr(0200): + n = int(s.rstrip(NUL + " ") or "0", 8) + else: + n = 0L + for i in xrange(len(s) - 1): + n <<= 8 + n += ord(s[i + 1]) + return n + +def itn(n, digits=8, posix=False): + """Convert a python number to a number field. + """ + # POSIX 1003.1-1988 requires numbers to be encoded as a string of + # octal digits followed by a null-byte, this allows values up to + # (8**(digits-1))-1. GNU tar allows storing numbers greater than + # that if necessary. A leading 0200 byte indicates this particular + # encoding, the following digits-1 bytes are a big-endian + # representation. This allows values up to (256**(digits-1))-1. + if 0 <= n < 8 ** (digits - 1): + s = "%0*o" % (digits - 1, n) + NUL + else: + if posix: + raise ValueError("overflow in number field") + + if n < 0: + # XXX We mimic GNU tar's behaviour with negative numbers, + # this could raise OverflowError. + n = struct.unpack("L", struct.pack("l", n))[0] + + s = "" + for i in xrange(digits - 1): + s = chr(n & 0377) + s + n >>= 8 + s = chr(0200) + s + return s + +def calc_chksums(buf): + """Calculate the checksum for a member's header by summing up all + characters except for the chksum field which is treated as if + it was filled with spaces. According to the GNU tar sources, + some tars (Sun and NeXT) calculate chksum with signed char, + which will be different if there are chars in the buffer with + the high bit set. So we calculate two checksums, unsigned and + signed. 
+ """ + unsigned_chksum = 256 + sum(struct.unpack("148B", buf[:148]) + struct.unpack("356B", buf[156:512])) + signed_chksum = 256 + sum(struct.unpack("148b", buf[:148]) + struct.unpack("356b", buf[156:512])) + return unsigned_chksum, signed_chksum + +def copyfileobj(src, dst, length=None): + """Copy length bytes from fileobj src to fileobj dst. + If length is None, copy the entire content. + """ + if length == 0: + return + if length is None: + shutil.copyfileobj(src, dst) + return + + BUFSIZE = 16 * 1024 + blocks, remainder = divmod(length, BUFSIZE) + for b in xrange(blocks): + buf = src.read(BUFSIZE) + if len(buf) < BUFSIZE: + raise IOError("end of file reached") + dst.write(buf) + + if remainder != 0: + buf = src.read(remainder) + if len(buf) < remainder: + raise IOError("end of file reached") + dst.write(buf) + return + +filemode_table = ( + ((S_IFLNK, "l"), + (S_IFREG, "-"), + (S_IFBLK, "b"), + (S_IFDIR, "d"), + (S_IFCHR, "c"), + (S_IFIFO, "p")), + + ((TUREAD, "r"),), + ((TUWRITE, "w"),), + ((TUEXEC|TSUID, "s"), + (TSUID, "S"), + (TUEXEC, "x")), + + ((TGREAD, "r"),), + ((TGWRITE, "w"),), + ((TGEXEC|TSGID, "s"), + (TSGID, "S"), + (TGEXEC, "x")), + + ((TOREAD, "r"),), + ((TOWRITE, "w"),), + ((TOEXEC|TSVTX, "t"), + (TSVTX, "T"), + (TOEXEC, "x")) +) + +def filemode(mode): + """Convert a file's mode to a string of the form + -rwxrwxrwx. 
+ Used by TarFile.list() + """ + perm = [] + for table in filemode_table: + for bit, char in table: + if mode & bit == bit: + perm.append(char) + break + else: + perm.append("-") + return "".join(perm) + +if os.sep != "/": + normpath = lambda path: os.path.normpath(path).replace(os.sep, "/") +else: + normpath = os.path.normpath + +class TarError(Exception): + """Base exception.""" + pass +class ExtractError(TarError): + """General exception for extract errors.""" + pass +class ReadError(TarError): + """Exception for unreadble tar archives.""" + pass +class CompressionError(TarError): + """Exception for unavailable compression methods.""" + pass +class StreamError(TarError): + """Exception for unsupported operations on stream-like TarFiles.""" + pass + +#--------------------------- +# internal stream interface +#--------------------------- +class _LowLevelFile: + """Low-level file object. Supports reading and writing. + It is used instead of a regular file object for streaming + access. + """ + + def __init__(self, name, mode): + mode = { + "r": os.O_RDONLY, + "w": os.O_WRONLY | os.O_CREAT | os.O_TRUNC, + }[mode] + if hasattr(os, "O_BINARY"): + mode |= os.O_BINARY + self.fd = os.open(name, mode) + + def close(self): + os.close(self.fd) + + def read(self, size): + return os.read(self.fd, size) + + def write(self, s): + os.write(self.fd, s) + +class _Stream: + """Class that serves as an adapter between TarFile and + a stream-like object. The stream-like object only + needs to have a read() or write() method and is accessed + blockwise. Use of gzip or bzip2 compression is possible. + A stream-like object could be for example: sys.stdin, + sys.stdout, a socket, a tape device etc. + + _Stream is intended to be used only internally. + """ + + def __init__(self, name, mode, comptype, fileobj, bufsize): + """Construct a _Stream object. 
+ """ + self._extfileobj = True + if fileobj is None: + fileobj = _LowLevelFile(name, mode) + self._extfileobj = False + + if comptype == '*': + # Enable transparent compression detection for the + # stream interface + fileobj = _StreamProxy(fileobj) + comptype = fileobj.getcomptype() + + self.name = name or "" + self.mode = mode + self.comptype = comptype + self.fileobj = fileobj + self.bufsize = bufsize + self.buf = "" + self.pos = 0L + self.closed = False + + if comptype == "gz": + try: + import zlib + except ImportError: + raise CompressionError("zlib module is not available") + self.zlib = zlib + self.crc = zlib.crc32("") + if mode == "r": + self._init_read_gz() + else: + self._init_write_gz() + + if comptype == "bz2": + try: + import bz2 + except ImportError: + raise CompressionError("bz2 module is not available") + if mode == "r": + self.dbuf = "" + self.cmp = bz2.BZ2Decompressor() + else: + self.cmp = bz2.BZ2Compressor() + + def __del__(self): + if hasattr(self, "closed") and not self.closed: + self.close() + + def _init_write_gz(self): + """Initialize for writing with gzip compression. + """ + self.cmp = self.zlib.compressobj(9, self.zlib.DEFLATED, + -self.zlib.MAX_WBITS, + self.zlib.DEF_MEM_LEVEL, + 0) + timestamp = struct.pack("<L", long(time.time())) + self.__write("\037\213\010\010%s\002\377" % timestamp) + if self.name.endswith(".gz"): + self.name = self.name[:-3] + self.__write(self.name + NUL) + + def write(self, s): + """Write string s to the stream. + """ + if self.comptype == "gz": + self.crc = self.zlib.crc32(s, self.crc) + self.pos += len(s) + if self.comptype != "tar": + s = self.cmp.compress(s) + self.__write(s) + + def __write(self, s): + """Write string s to the stream if a whole new block + is ready to be written. + """ + self.buf += s + while len(self.buf) > self.bufsize: + self.fileobj.write(self.buf[:self.bufsize]) + self.buf = self.buf[self.bufsize:] + + def close(self): + """Close the _Stream object. 
No operation should be + done on it afterwards. + """ + if self.closed: + return + + if self.mode == "w" and self.comptype != "tar": + self.buf += self.cmp.flush() + + if self.mode == "w" and self.buf: + self.fileobj.write(self.buf) + self.buf = "" + if self.comptype == "gz": + # The native zlib crc is an unsigned 32-bit integer, but + # the Python wrapper implicitly casts that to a signed C + # long. So, on a 32-bit box self.crc may "look negative", + # while the same crc on a 64-bit box may "look positive". + # To avoid irksome warnings from the `struct` module, force + # it to look positive on all boxes. + self.fileobj.write(struct.pack("<L", self.crc & 0xffffffffL)) + self.fileobj.write(struct.pack("<L", self.pos & 0xffffFFFFL)) + + if not self._extfileobj: + self.fileobj.close() + + self.closed = True + + def _init_read_gz(self): + """Initialize for reading a gzip compressed fileobj. + """ + self.cmp = self.zlib.decompressobj(-self.zlib.MAX_WBITS) + self.dbuf = "" + + # taken from gzip.GzipFile with some alterations + if self.__read(2) != "\037\213": + raise ReadError("not a gzip file") + if self.__read(1) != "\010": + raise CompressionError("unsupported compression method") + + flag = ord(self.__read(1)) + self.__read(6) + + if flag & 4: + xlen = ord(self.__read(1)) + 256 * ord(self.__read(1)) + self.read(xlen) + if flag & 8: + while True: + s = self.__read(1) + if not s or s == NUL: + break + if flag & 16: + while True: + s = self.__read(1) + if not s or s == NUL: + break + if flag & 2: + self.__read(2) + + def tell(self): + """Return the stream's file pointer position. + """ + return self.pos + + def seek(self, pos=0): + """Set the stream's file pointer to pos. Negative seeking + is forbidden. 
+ """ + if pos - self.pos >= 0: + blocks, remainder = divmod(pos - self.pos, self.bufsize) + for i in xrange(blocks): + self.read(self.bufsize) + self.read(remainder) + else: + raise StreamError("seeking backwards is not allowed") + return self.pos + + def read(self, size=None): + """Return the next size number of bytes from the stream. + If size is not defined, return all bytes of the stream + up to EOF. + """ + if size is None: + t = [] + while True: + buf = self._read(self.bufsize) + if not buf: + break + t.append(buf) + buf = "".join(t) + else: + buf = self._read(size) + self.pos += len(buf) + return buf + + def _read(self, size): + """Return size bytes from the stream. + """ + if self.comptype == "tar": + return self.__read(size) + + c = len(self.dbuf) + t = [self.dbuf] + while c < size: + buf = self.__read(self.bufsize) + if not buf: + break + buf = self.cmp.decompress(buf) + t.append(buf) + c += len(buf) + t = "".join(t) + self.dbuf = t[size:] + return t[:size] + + def __read(self, size): + """Return size bytes from stream. If internal buffer is empty, + read another block from the stream. + """ + c = len(self.buf) + t = [self.buf] + while c < size: + buf = self.fileobj.read(self.bufsize) + if not buf: + break + t.append(buf) + c += len(buf) + t = "".join(t) + self.buf = t[size:] + return t[:size] +# class _Stream + +class _StreamProxy(object): + """Small proxy class that enables transparent compression + detection for the Stream interface (mode 'r|*'). 
+ """ + + def __init__(self, fileobj): + self.fileobj = fileobj + self.buf = self.fileobj.read(BLOCKSIZE) + + def read(self, size): + self.read = self.fileobj.read + return self.buf + + def getcomptype(self): + if self.buf.startswith("\037\213\010"): + return "gz" + if self.buf.startswith("BZh91"): + return "bz2" + return "tar" + + def close(self): + self.fileobj.close() +# class StreamProxy + +class _BZ2Proxy(object): + """Small proxy class that enables external file object + support for "r:bz2" and "w:bz2" modes. This is actually + a workaround for a limitation in bz2 module's BZ2File + class which (unlike gzip.GzipFile) has no support for + a file object argument. + """ + + blocksize = 16 * 1024 + + def __init__(self, fileobj, mode): + self.fileobj = fileobj + self.mode = mode + self.init() + + def init(self): + import bz2 + self.pos = 0 + if self.mode == "r": + self.bz2obj = bz2.BZ2Decompressor() + self.fileobj.seek(0) + self.buf = "" + else: + self.bz2obj = bz2.BZ2Compressor() + + def read(self, size): + b = [self.buf] + x = len(self.buf) + while x < size: + try: + raw = self.fileobj.read(self.blocksize) + data = self.bz2obj.decompress(raw) + b.append(data) + except EOFError: + break + x += len(data) + self.buf = "".join(b) + + buf = self.buf[:size] + self.buf = self.buf[size:] + self.pos += len(buf) + return buf + + def seek(self, pos): + if pos < self.pos: + self.init() + self.read(pos - self.pos) + + def tell(self): + return self.pos + + def write(self, data): + self.pos += len(data) + raw = self.bz2obj.compress(data) + self.fileobj.write(raw) + + def close(self): + if self.mode == "w": + raw = self.bz2obj.flush() + self.fileobj.write(raw) + self.fileobj.close() +# class _BZ2Proxy + +#------------------------ +# Extraction file object +#------------------------ +class _FileInFile(object): + """A thin wrapper around an existing file object that + provides a part of its data as an individual file + object. 
+ """ + + def __init__(self, fileobj, offset, size, sparse=None): + self.fileobj = fileobj + self.offset = offset + self.size = size + self.sparse = sparse + self.position = 0 + + def tell(self): + """Return the current file position. + """ + return self.position + + def seek(self, position): + """Seek to a position in the file. + """ + self.position = position + + def read(self, size=None): + """Read data from the file. + """ + if size is None: + size = self.size - self.position + else: + size = min(size, self.size - self.position) + + if self.sparse is None: + return self.readnormal(size) + else: + return self.readsparse(size) + + def readnormal(self, size): + """Read operation for regular files. + """ + self.fileobj.seek(self.offset + self.position) + self.position += size + return self.fileobj.read(size) + + def readsparse(self, size): + """Read operation for sparse files. + """ + data = [] + while size > 0: + buf = self.readsparsesection(size) + if not buf: + break + size -= len(buf) + data.append(buf) + return "".join(data) + + def readsparsesection(self, size): + """Read a single section of a sparse file. + """ + section = self.sparse.find(self.position) + + if section is None: + return "" + + size = min(size, section.offset + section.size - self.position) + + if isinstance(section, _data): + realpos = section.realpos + self.position - section.offset + self.fileobj.seek(self.offset + realpos) + self.position += size + return self.fileobj.read(size) + else: + self.position += size + return NUL * size +#class _FileInFile + + +class ExFileObject(object): + """File-like object for reading an archive member. + Is returned by TarFile.extractfile(). 
+ """ + blocksize = 1024 + + def __init__(self, tarfile, tarinfo): + self.fileobj = _FileInFile(tarfile.fileobj, + tarinfo.offset_data, + tarinfo.size, + getattr(tarinfo, "sparse", None)) + self.name = tarinfo.name + self.mode = "r" + self.closed = False + self.size = tarinfo.size + + self.position = 0 + self.buffer = "" + + def read(self, size=None): + """Read at most size bytes from the file. If size is not + present or None, read all data until EOF is reached. + """ + if self.closed: + raise ValueError("I/O operation on closed file") + + buf = "" + if self.buffer: + if size is None: + buf = self.buffer + self.buffer = "" + else: + buf = self.buffer[:size] + self.buffer = self.buffer[size:] + + if size is None: + buf += self.fileobj.read() + else: + buf += self.fileobj.read(size - len(buf)) + + self.position += len(buf) + return buf + + def readline(self, size=-1): + """Read one entire line from the file. If size is present + and non-negative, return a string with at most that + size, which may be an incomplete line. + """ + if self.closed: + raise ValueError("I/O operation on closed file") + + if "\n" in self.buffer: + pos = self.buffer.find("\n") + 1 + else: + buffers = [self.buffer] + while True: + buf = self.fileobj.read(self.blocksize) + buffers.append(buf) + if not buf or "\n" in buf: + self.buffer = "".join(buffers) + pos = self.buffer.find("\n") + 1 + if pos == 0: + # no newline found. + pos = len(self.buffer) + break + + if size != -1: + pos = min(size, pos) + + buf = self.buffer[:pos] + self.buffer = self.buffer[pos:] + self.position += len(buf) + return buf + + def readlines(self): + """Return a list with all remaining lines. + """ + result = [] + while True: + line = self.readline() + if not line: break + result.append(line) + return result + + def tell(self): + """Return the current file position. 
+ """ + if self.closed: + raise ValueError("I/O operation on closed file") + + return self.position + + def seek(self, pos, whence=os.SEEK_SET): + """Seek to a position in the file. + """ + if self.closed: + raise ValueError("I/O operation on closed file") + + if whence == os.SEEK_SET: + self.position = min(max(pos, 0), self.size) + elif whence == os.SEEK_CUR: + if pos < 0: + self.position = max(self.position + pos, 0) + else: + self.position = min(self.position + pos, self.size) + elif whence == os.SEEK_END: + self.position = max(min(self.size + pos, self.size), 0) + else: + raise ValueError("Invalid argument") + + self.buffer = "" + self.fileobj.seek(self.position) + + def close(self): + """Close the file object. + """ + self.closed = True + + def __iter__(self): + """Get an iterator over the file's lines. + """ + while True: + line = self.readline() + if not line: + break + yield line +#class ExFileObject + +#------------------ +# Exported Classes +#------------------ +class TarInfo(object): + """Informational class which holds the details about an + archive member given by a tar header block. + TarInfo objects are returned by TarFile.getmember(), + TarFile.getmembers() and TarFile.gettarinfo() and are + usually created internally. + """ + + def __init__(self, name=""): + """Construct a TarInfo object. name is the optional name + of the member. 
+ """ + self.name = name # member name (dirnames must end with '/') + self.mode = 0666 # file permissions + self.uid = 0 # user id + self.gid = 0 # group id + self.size = 0 # file size + self.mtime = 0 # modification time + self.chksum = 0 # header checksum + self.type = REGTYPE # member type + self.linkname = "" # link name + self.uname = "user" # user name + self.gname = "group" # group name + self.devmajor = 0 # device major number + self.devminor = 0 # device minor number + + self.offset = 0 # the tar header starts here + self.offset_data = 0 # the file's data starts here + + def __repr__(self): + return "<%s %r at %#x>" % (self.__class__.__name__,self.name,id(self)) + + @classmethod + def frombuf(cls, buf): + """Construct a TarInfo object from a 512 byte string buffer. + """ + if len(buf) != BLOCKSIZE: + raise ValueError("truncated header") + if buf.count(NUL) == BLOCKSIZE: + raise ValueError("empty header") + + tarinfo = cls() + tarinfo.buf = buf + tarinfo.name = buf[0:100].rstrip(NUL) + tarinfo.mode = nti(buf[100:108]) + tarinfo.uid = nti(buf[108:116]) + tarinfo.gid = nti(buf[116:124]) + tarinfo.size = nti(buf[124:136]) + tarinfo.mtime = nti(buf[136:148]) + tarinfo.chksum = nti(buf[148:156]) + tarinfo.type = buf[156:157] + tarinfo.linkname = buf[157:257].rstrip(NUL) + tarinfo.uname = buf[265:297].rstrip(NUL) + tarinfo.gname = buf[297:329].rstrip(NUL) + tarinfo.devmajor = nti(buf[329:337]) + tarinfo.devminor = nti(buf[337:345]) + prefix = buf[345:500].rstrip(NUL) + + if prefix and not tarinfo.issparse(): + tarinfo.name = prefix + "/" + tarinfo.name + + if tarinfo.chksum not in calc_chksums(buf): + raise ValueError("invalid header") + return tarinfo + + def tobuf(self, posix=False): + """Return a tar header as a string of 512 byte blocks. + """ + buf = "" + type = self.type + prefix = "" + + if self.name.endswith("/"): + type = DIRTYPE + + if type in (GNUTYPE_LONGNAME, GNUTYPE_LONGLINK): + # Prevent "././@LongLink" from being normalized. 
+ name = self.name + else: + name = normpath(self.name) + + if type == DIRTYPE: + # directories should end with '/' + name += "/" + + linkname = self.linkname + if linkname: + # if linkname is empty we end up with a '.' + linkname = normpath(linkname) + + if posix: + if self.size > MAXSIZE_MEMBER: + raise ValueError("file is too large (>= 8 GB)") + + if len(self.linkname) > LENGTH_LINK: + raise ValueError("linkname is too long (>%d)" % (LENGTH_LINK)) + + if len(name) > LENGTH_NAME: + prefix = name[:LENGTH_PREFIX + 1] + while prefix and prefix[-1] != "/": + prefix = prefix[:-1] + + name = name[len(prefix):] + prefix = prefix[:-1] + + if not prefix or len(name) > LENGTH_NAME: + raise ValueError("name is too long") + + else: + if len(self.linkname) > LENGTH_LINK: + buf += self._create_gnulong(self.linkname, GNUTYPE_LONGLINK) + + if len(name) > LENGTH_NAME: + buf += self._create_gnulong(name, GNUTYPE_LONGNAME) + + parts = [ + stn(name, 100), + itn(self.mode & 07777, 8, posix), + itn(self.uid, 8, posix), + itn(self.gid, 8, posix), + itn(self.size, 12, posix), + itn(self.mtime, 12, posix), + " ", # checksum field + type, + stn(self.linkname, 100), + stn(MAGIC, 6), + stn(VERSION, 2), + stn(self.uname, 32), + stn(self.gname, 32), + itn(self.devmajor, 8, posix), + itn(self.devminor, 8, posix), + stn(prefix, 155) + ] + + buf += struct.pack("%ds" % BLOCKSIZE, "".join(parts)) + chksum = calc_chksums(buf[-BLOCKSIZE:])[0] + buf = buf[:-364] + "%06o\0" % chksum + buf[-357:] + self.buf = buf + return buf + + def _create_gnulong(self, name, type): + """Create a GNU longname/longlink header from name. + It consists of an extended tar header, with the length + of the longname as size, followed by data blocks, + which contain the longname as a null terminated string. 
+ """ + name += NUL + + tarinfo = self.__class__() + tarinfo.name = "././@LongLink" + tarinfo.type = type + tarinfo.mode = 0 + tarinfo.size = len(name) + + # create extended header + buf = tarinfo.tobuf() + # create name blocks + buf += name + blocks, remainder = divmod(len(name), BLOCKSIZE) + if remainder > 0: + buf += (BLOCKSIZE - remainder) * NUL + return buf + + def isreg(self): + return self.type in REGULAR_TYPES + def isfile(self): + return self.isreg() + def isdir(self): + return self.type == DIRTYPE + def issym(self): + return self.type == SYMTYPE + def islnk(self): + return self.type == LNKTYPE + def ischr(self): + return self.type == CHRTYPE + def isblk(self): + return self.type == BLKTYPE + def isfifo(self): + return self.type == FIFOTYPE + def issparse(self): + return self.type == GNUTYPE_SPARSE + def isdev(self): + return self.type in (CHRTYPE, BLKTYPE, FIFOTYPE) +# class TarInfo + +class TarFile(object): + """The TarFile Class provides an interface to tar archives. + """ + + debug = 0 # May be set from 0 (no msgs) to 3 (all msgs) + + dereference = False # If true, add content of linked file to the + # tar file, else the link. + + ignore_zeros = False # If true, skips empty or invalid blocks and + # continues processing. + + errorlevel = 0 # If 0, fatal errors only appear in debug + # messages (if debug >= 0). If > 0, errors + # are passed to the caller as exceptions. + + posix = False # If True, generates POSIX.1-1990-compliant + # archives (no GNU extensions!) + + fileobject = ExFileObject + + def __init__(self, name=None, mode="r", fileobj=None): + """Open an (uncompressed) tar archive `name'. `mode' is either 'r' to + read from an existing archive, 'a' to append data to an existing + file or 'w' to create a new file overwriting an existing one. `mode' + defaults to 'r'. + If `fileobj' is given, it is used for reading or writing data. If it + can be determined, `mode' is overridden by `fileobj's mode. 
+ `fileobj' is not closed, when TarFile is closed. + """ + self.name = os.path.abspath(name) + + if len(mode) > 1 or mode not in "raw": + raise ValueError("mode must be 'r', 'a' or 'w'") + self._mode = mode + self.mode = {"r": "rb", "a": "r+b", "w": "wb"}[mode] + + if not fileobj: + fileobj = file(self.name, self.mode) + self._extfileobj = False + else: + if self.name is None and hasattr(fileobj, "name"): + self.name = os.path.abspath(fileobj.name) + if hasattr(fileobj, "mode"): + self.mode = fileobj.mode + self._extfileobj = True + self.fileobj = fileobj + + # Init datastructures + self.closed = False + self.members = [] # list of members as TarInfo objects + self._loaded = False # flag if all members have been read + self.offset = 0L # current position in the archive file + self.inodes = {} # dictionary caching the inodes of + # archive members already added + + if self._mode == "r": + self.firstmember = None + self.firstmember = self.next() + + if self._mode == "a": + # Move to the end of the archive, + # before the first empty block. + self.firstmember = None + while True: + try: + tarinfo = self.next() + except ReadError: + self.fileobj.seek(0) + break + if tarinfo is None: + self.fileobj.seek(- BLOCKSIZE, 1) + break + + if self._mode in "aw": + self._loaded = True + + #-------------------------------------------------------------------------- + # Below are the classmethods which act as alternate constructors to the + # TarFile class. The open() method is the only one that is needed for + # public use; it is the "super"-constructor and is able to select an + # adequate "sub"-constructor for a particular compression using the mapping + # from OPEN_METH. + # + # This concept allows one to subclass TarFile without losing the comfort of + # the super-constructor. A sub-constructor is registered and made available + # by adding it to the mapping in OPEN_METH. 
+ + @classmethod + def open(cls, name=None, mode="r", fileobj=None, bufsize=20*512): + """Open a tar archive for reading, writing or appending. Return + an appropriate TarFile class. + + mode: + 'r' or 'r:*' open for reading with transparent compression + 'r:' open for reading exclusively uncompressed + 'r:gz' open for reading with gzip compression + 'r:bz2' open for reading with bzip2 compression + 'a' or 'a:' open for appending + 'w' or 'w:' open for writing without compression + 'w:gz' open for writing with gzip compression + 'w:bz2' open for writing with bzip2 compression + + 'r|*' open a stream of tar blocks with transparent compression + 'r|' open an uncompressed stream of tar blocks for reading + 'r|gz' open a gzip compressed stream of tar blocks + 'r|bz2' open a bzip2 compressed stream of tar blocks + 'w|' open an uncompressed stream for writing + 'w|gz' open a gzip compressed stream for writing + 'w|bz2' open a bzip2 compressed stream for writing + """ + + if not name and not fileobj: + raise ValueError("nothing to open") + + if mode in ("r", "r:*"): + # Find out which *open() is appropriate for opening the file. + for comptype in cls.OPEN_METH: + func = getattr(cls, cls.OPEN_METH[comptype]) + if fileobj is not None: + saved_pos = fileobj.tell() + try: + return func(name, "r", fileobj) + except (ReadError, CompressionError): + if fileobj is not None: + fileobj.seek(saved_pos) + continue + raise ReadError("file could not be opened successfully") + + elif ":" in mode: + filemode, comptype = mode.split(":", 1) + filemode = filemode or "r" + comptype = comptype or "tar" + + # Select the *open() function according to + # given compression. 
+ if comptype in cls.OPEN_METH: + func = getattr(cls, cls.OPEN_METH[comptype]) + else: + raise CompressionError("unknown compression type %r" % comptype) + return func(name, filemode, fileobj) + + elif "|" in mode: + filemode, comptype = mode.split("|", 1) + filemode = filemode or "r" + comptype = comptype or "tar" + + if filemode not in "rw": + raise ValueError("mode must be 'r' or 'w'") + + t = cls(name, filemode, + _Stream(name, filemode, comptype, fileobj, bufsize)) + t._extfileobj = False + return t + + elif mode in "aw": + return cls.taropen(name, mode, fileobj) + + raise ValueError("undiscernible mode") + + @classmethod + def taropen(cls, name, mode="r", fileobj=None): + """Open uncompressed tar archive name for reading or writing. + """ + if len(mode) > 1 or mode not in "raw": + raise ValueError("mode must be 'r', 'a' or 'w'") + return cls(name, mode, fileobj) + + @classmethod + def gzopen(cls, name, mode="r", fileobj=None, compresslevel=9): + """Open gzip compressed tar archive name for reading or writing. + Appending is not allowed. + """ + if len(mode) > 1 or mode not in "rw": + raise ValueError("mode must be 'r' or 'w'") + + try: + import gzip + gzip.GzipFile + except (ImportError, AttributeError): + raise CompressionError("gzip module is not available") + + if fileobj is None: + fileobj = file(name, mode + "b") + + try: + t = cls.taropen(name, mode, + gzip.GzipFile(name, mode, compresslevel, fileobj)) + except IOError: + raise ReadError("not a gzip file") + t._extfileobj = False + return t + + @classmethod + def bz2open(cls, name, mode="r", fileobj=None, compresslevel=9): + """Open bzip2 compressed tar archive name for reading or writing. + Appending is not allowed. 
+ """ + if len(mode) > 1 or mode not in "rw": + raise ValueError("mode must be 'r' or 'w'.") + + try: + import bz2 + except ImportError: + raise CompressionError("bz2 module is not available") + + if fileobj is not None: + fileobj = _BZ2Proxy(fileobj, mode) + else: + fileobj = bz2.BZ2File(name, mode, compresslevel=compresslevel) + + try: + t = cls.taropen(name, mode, fileobj) + except IOError: + raise ReadError("not a bzip2 file") + t._extfileobj = False + return t + + # All *open() methods are registered here. + OPEN_METH = { + "tar": "taropen", # uncompressed tar + "gz": "gzopen", # gzip compressed tar + "bz2": "bz2open" # bzip2 compressed tar + } + + #-------------------------------------------------------------------------- + # The public methods which TarFile provides: + + def close(self): + """Close the TarFile. In write-mode, two finishing zero blocks are + appended to the archive. + """ + if self.closed: + return + + if self._mode in "aw": + self.fileobj.write(NUL * (BLOCKSIZE * 2)) + self.offset += (BLOCKSIZE * 2) + # fill up the end with zero-blocks + # (like option -b20 for tar does) + blocks, remainder = divmod(self.offset, RECORDSIZE) + if remainder > 0: + self.fileobj.write(NUL * (RECORDSIZE - remainder)) + + if not self._extfileobj: + self.fileobj.close() + self.closed = True + + def getmember(self, name): + """Return a TarInfo object for member `name'. If `name' can not be + found in the archive, KeyError is raised. If a member occurs more + than once in the archive, its last occurence is assumed to be the + most up-to-date version. + """ + tarinfo = self._getmember(name) + if tarinfo is None: + raise KeyError("filename %r not found" % name) + return tarinfo + + def getmembers(self): + """Return the members of the archive as a list of TarInfo objects. The + list has the same order as the members in the archive. 
+ """ + self._check() + if not self._loaded: # if we want to obtain a list of + self._load() # all members, we first have to + # scan the whole archive. + return self.members + + def getnames(self): + """Return the members of the archive as a list of their names. It has + the same order as the list returned by getmembers(). + """ + return [tarinfo.name for tarinfo in self.getmembers()] + + def gettarinfo(self, name=None, arcname=None, fileobj=None): + """Create a TarInfo object for either the file `name' or the file + object `fileobj' (using os.fstat on its file descriptor). You can + modify some of the TarInfo's attributes before you add it using + addfile(). If given, `arcname' specifies an alternative name for the + file in the archive. + """ + self._check("aw") + + # When fileobj is given, replace name by + # fileobj's real name. + if fileobj is not None: + name = fileobj.name + + # Building the name of the member in the archive. + # Backward slashes are converted to forward slashes, + # Absolute paths are turned to relative paths. + if arcname is None: + arcname = name + arcname = normpath(arcname) + drv, arcname = os.path.splitdrive(arcname) + while arcname[0:1] == "/": + arcname = arcname[1:] + + # Now, fill the TarInfo object with + # information specific for the file. + tarinfo = TarInfo() + + # Use os.stat or os.lstat, depending on platform + # and if symlinks shall be resolved. + if fileobj is None: + if hasattr(os, "lstat") and not self.dereference: + statres = os.lstat(name) + else: + statres = os.stat(name) + else: + statres = os.fstat(fileobj.fileno()) + linkname = "" + + stmd = statres.st_mode + if stat.S_ISREG(stmd): + inode = (statres.st_ino, statres.st_dev) + if not self.dereference and \ + statres.st_nlink > 1 and inode in self.inodes: + # Is it a hardlink to an already + # archived file? + type = LNKTYPE + linkname = self.inodes[inode] + else: + # The inode is added only if its valid. + # For win32 it is always 0. 
+ type = REGTYPE + if inode[0]: + self.inodes[inode] = arcname + elif stat.S_ISDIR(stmd): + type = DIRTYPE + if arcname[-1:] != "/": + arcname += "/" + elif stat.S_ISFIFO(stmd): + type = FIFOTYPE + elif stat.S_ISLNK(stmd): + type = SYMTYPE + linkname = os.readlink(name) + elif stat.S_ISCHR(stmd): + type = CHRTYPE + elif stat.S_ISBLK(stmd): + type = BLKTYPE + else: + return None + + # Fill the TarInfo object with all + # information we can get. + tarinfo.name = arcname + tarinfo.mode = stmd + tarinfo.uid = statres.st_uid + tarinfo.gid = statres.st_gid + if stat.S_ISREG(stmd): + tarinfo.size = statres.st_size + else: + tarinfo.size = 0L + tarinfo.mtime = statres.st_mtime + tarinfo.type = type + tarinfo.linkname = linkname + if pwd: + try: + tarinfo.uname = pwd.getpwuid(tarinfo.uid)[0] + except KeyError: + pass + if grp: + try: + tarinfo.gname = grp.getgrgid(tarinfo.gid)[0] + except KeyError: + pass + + if type in (CHRTYPE, BLKTYPE): + if hasattr(os, "major") and hasattr(os, "minor"): + tarinfo.devmajor = os.major(statres.st_rdev) + tarinfo.devminor = os.minor(statres.st_rdev) + return tarinfo + + def list(self, verbose=True): + """Print a table of contents to sys.stdout. If `verbose' is False, only + the names of the members are printed. If it is True, an `ls -l'-like + output is produced. + """ + self._check() + + for tarinfo in self: + if verbose: + print filemode(tarinfo.mode), + print "%s/%s" % (tarinfo.uname or tarinfo.uid, + tarinfo.gname or tarinfo.gid), + if tarinfo.ischr() or tarinfo.isblk(): + print "%10s" % ("%d,%d" \ + % (tarinfo.devmajor, tarinfo.devminor)), + else: + print "%10d" % tarinfo.size, + print "%d-%02d-%02d %02d:%02d:%02d" \ + % time.localtime(tarinfo.mtime)[:6], + + print tarinfo.name, + + if verbose: + if tarinfo.issym(): + print "->", tarinfo.linkname, + if tarinfo.islnk(): + print "link to", tarinfo.linkname, + print + + def add(self, name, arcname=None, recursive=True): + """Add the file `name' to the archive. 
`name' may be any type of file + (directory, fifo, symbolic link, etc.). If given, `arcname' + specifies an alternative name for the file in the archive. + Directories are added recursively by default. This can be avoided by + setting `recursive' to False. + """ + self._check("aw") + + if arcname is None: + arcname = name + + # Skip if somebody tries to archive the archive... + if self.name is not None and os.path.abspath(name) == self.name: + self._dbg(2, "tarfile: Skipped %r" % name) + return + + # Special case: The user wants to add the current + # working directory. + if name == ".": + if recursive: + if arcname == ".": + arcname = "" + for f in os.listdir("."): + self.add(f, os.path.join(arcname, f)) + return + + self._dbg(1, name) + + # Create a TarInfo object from the file. + tarinfo = self.gettarinfo(name, arcname) + + if tarinfo is None: + self._dbg(1, "tarfile: Unsupported type %r" % name) + return + + # Append the tar header and data to the archive. + if tarinfo.isreg(): + f = file(name, "rb") + self.addfile(tarinfo, f) + f.close() + + elif tarinfo.isdir(): + self.addfile(tarinfo) + if recursive: + for f in os.listdir(name): + self.add(os.path.join(name, f), os.path.join(arcname, f)) + + else: + self.addfile(tarinfo) + + def addfile(self, tarinfo, fileobj=None): + """Add the TarInfo object `tarinfo' to the archive. If `fileobj' is + given, tarinfo.size bytes are read from it and added to the archive. + You can create TarInfo objects using gettarinfo(). + On Windows platforms, `fileobj' should always be opened with mode + 'rb' to avoid irritation about the file size. + """ + self._check("aw") + + tarinfo = copy.copy(tarinfo) + + buf = tarinfo.tobuf(self.posix) + self.fileobj.write(buf) + self.offset += len(buf) + + # If there's data to follow, append it. 
+ if fileobj is not None: + copyfileobj(fileobj, self.fileobj, tarinfo.size) + blocks, remainder = divmod(tarinfo.size, BLOCKSIZE) + if remainder > 0: + self.fileobj.write(NUL * (BLOCKSIZE - remainder)) + blocks += 1 + self.offset += blocks * BLOCKSIZE + + self.members.append(tarinfo) + + def extractall(self, path=".", members=None): + """Extract all members from the archive to the current working + directory and set owner, modification time and permissions on + directories afterwards. `path' specifies a different directory + to extract to. `members' is optional and must be a subset of the + list returned by getmembers(). + """ + directories = [] + + if members is None: + members = self + + for tarinfo in members: + if tarinfo.isdir(): + # Extract directory with a safe mode, so that + # all files below can be extracted as well. + try: + os.makedirs(os.path.join(path, tarinfo.name), 0777) + except EnvironmentError: + pass + directories.append(tarinfo) + else: + self.extract(tarinfo, path) + + # Reverse sort directories. + directories.sort(lambda a, b: cmp(a.name, b.name)) + directories.reverse() + + # Set correct owner, mtime and filemode on directories. + for tarinfo in directories: + path = os.path.join(path, tarinfo.name) + try: + self.chown(tarinfo, path) + self.utime(tarinfo, path) + self.chmod(tarinfo, path) + except ExtractError, e: + if self.errorlevel > 1: + raise + else: + self._dbg(1, "tarfile: %s" % e) + + def extract(self, member, path=""): + """Extract a member from the archive to the current working directory, + using its full name. Its file information is extracted as accurately + as possible. `member' may be a filename or a TarInfo object. You can + specify a different directory using `path'. + """ + self._check("r") + + if isinstance(member, TarInfo): + tarinfo = member + else: + tarinfo = self.getmember(member) + + # Prepare the link target for makelink(). 
+ if tarinfo.islnk(): + tarinfo._link_target = os.path.join(path, tarinfo.linkname) + + try: + self._extract_member(tarinfo, os.path.join(path, tarinfo.name)) + except EnvironmentError, e: + if self.errorlevel > 0: + raise + else: + if e.filename is None: + self._dbg(1, "tarfile: %s" % e.strerror) + else: + self._dbg(1, "tarfile: %s %r" % (e.strerror, e.filename)) + except ExtractError, e: + if self.errorlevel > 1: + raise + else: + self._dbg(1, "tarfile: %s" % e) + + def extractfile(self, member): + """Extract a member from the archive as a file object. `member' may be + a filename or a TarInfo object. If `member' is a regular file, a + file-like object is returned. If `member' is a link, a file-like + object is constructed from the link's target. If `member' is none of + the above, None is returned. + The file-like object is read-only and provides the following + methods: read(), readline(), readlines(), seek() and tell() + """ + self._check("r") + + if isinstance(member, TarInfo): + tarinfo = member + else: + tarinfo = self.getmember(member) + + if tarinfo.isreg(): + return self.fileobject(self, tarinfo) + + elif tarinfo.type not in SUPPORTED_TYPES: + # If a member's type is unknown, it is treated as a + # regular file. + return self.fileobject(self, tarinfo) + + elif tarinfo.islnk() or tarinfo.issym(): + if isinstance(self.fileobj, _Stream): + # A small but ugly workaround for the case that someone tries + # to extract a (sym)link as a file-object from a non-seekable + # stream of tar blocks. + raise StreamError("cannot extract (sym)link as file object") + else: + # A (sym)link's file object is its target's file object. + return self.extractfile(self._getmember(tarinfo.linkname, + tarinfo)) + else: + # If there's no data associated with the member (directory, chrdev, + # blkdev, etc.), return None instead of a file object. 
+ return None + + def _extract_member(self, tarinfo, targetpath): + """Extract the TarInfo object tarinfo to a physical + file called targetpath. + """ + # Fetch the TarInfo object for the given name + # and build the destination pathname, replacing + # forward slashes to platform specific separators. + if targetpath[-1:] == "/": + targetpath = targetpath[:-1] + targetpath = os.path.normpath(targetpath) + + # Create all upper directories. + upperdirs = os.path.dirname(targetpath) + if upperdirs and not os.path.exists(upperdirs): + ti = TarInfo() + ti.name = upperdirs + ti.type = DIRTYPE + ti.mode = 0777 + ti.mtime = tarinfo.mtime + ti.uid = tarinfo.uid + ti.gid = tarinfo.gid + ti.uname = tarinfo.uname + ti.gname = tarinfo.gname + try: + self._extract_member(ti, ti.name) + except: + pass + + if tarinfo.islnk() or tarinfo.issym(): + self._dbg(1, "%s -> %s" % (tarinfo.name, tarinfo.linkname)) + else: + self._dbg(1, tarinfo.name) + + if tarinfo.isreg(): + self.makefile(tarinfo, targetpath) + elif tarinfo.isdir(): + self.makedir(tarinfo, targetpath) + elif tarinfo.isfifo(): + self.makefifo(tarinfo, targetpath) + elif tarinfo.ischr() or tarinfo.isblk(): + self.makedev(tarinfo, targetpath) + elif tarinfo.islnk() or tarinfo.issym(): + self.makelink(tarinfo, targetpath) + elif tarinfo.type not in SUPPORTED_TYPES: + self.makeunknown(tarinfo, targetpath) + else: + self.makefile(tarinfo, targetpath) + + self.chown(tarinfo, targetpath) + if not tarinfo.issym(): + self.chmod(tarinfo, targetpath) + self.utime(tarinfo, targetpath) + + #-------------------------------------------------------------------------- + # Below are the different file methods. They are called via + # _extract_member() when extract() is called. They can be replaced in a + # subclass to implement other functionality. + + def makedir(self, tarinfo, targetpath): + """Make a directory called targetpath. 
+ """ + try: + os.mkdir(targetpath) + except EnvironmentError, e: + if e.errno != errno.EEXIST: + raise + + def makefile(self, tarinfo, targetpath): + """Make a file called targetpath. + """ + source = self.extractfile(tarinfo) + target = file(targetpath, "wb") + copyfileobj(source, target) + source.close() + target.close() + + def makeunknown(self, tarinfo, targetpath): + """Make a file from a TarInfo object with an unknown type + at targetpath. + """ + self.makefile(tarinfo, targetpath) + self._dbg(1, "tarfile: Unknown file type %r, " \ + "extracted as regular file." % tarinfo.type) + + def makefifo(self, tarinfo, targetpath): + """Make a fifo called targetpath. + """ + if hasattr(os, "mkfifo"): + os.mkfifo(targetpath) + else: + raise ExtractError("fifo not supported by system") + + def makedev(self, tarinfo, targetpath): + """Make a character or block device called targetpath. + """ + ... [truncated message content] |
From: <pj...@us...> - 2007-12-24 19:36:06
|
Revision: 3863 http://jython.svn.sourceforge.net/jython/?rev=3863&view=rev Author: pjenvey Date: 2007-12-24 11:36:03 -0800 (Mon, 24 Dec 2007) Log Message: ----------- restrict reference counting code to isinstance(sock, _nonblocking_api_mixin) instead not isinstance(sock, _closedsocket). allows urllib2's usage of _fileobject to wrap a non socket object fixes #1850722 thanks Daniel Menezes Modified Paths: -------------- trunk/jython/Lib/socket.py trunk/jython/Lib/test/test_socket.py Modified: trunk/jython/Lib/socket.py =================================================================== --- trunk/jython/Lib/socket.py 2007-12-24 19:01:06 UTC (rev 3862) +++ trunk/jython/Lib/socket.py 2007-12-24 19:36:03 UTC (rev 3863) @@ -799,7 +799,7 @@ if _sock is None: _sock = _realsocket(family, type, proto) _sock.reference_count += 1 - elif not isinstance(_sock, _closedsocket): + elif isinstance(_sock, _nonblocking_api_mixin): _sock.reference_count += 1 self._sock = _sock self.send = self._sock.send @@ -810,7 +810,7 @@ def close(self): _sock = self._sock - if not isinstance(_sock, _closedsocket): + if isinstance(_sock, _nonblocking_api_mixin): _sock.close_lock.acquire() try: _sock.reference_count -=1 @@ -833,7 +833,7 @@ Return a new socket object connected to the same system resource.""" _sock = self._sock - if isinstance(_sock, _closedsocket): + if not isinstance(_sock, _nonblocking_api_mixin): return _socketobject(_sock=_sock) _sock.close_lock.acquire() @@ -849,7 +849,7 @@ Return a regular file object corresponding to the socket. 
The mode and bufsize arguments are as for the built-in open() function.""" _sock = self._sock - if isinstance(_sock, _closedsocket): + if not isinstance(_sock, _nonblocking_api_mixin): return _fileobject(_sock, mode, bufsize) _sock.close_lock.acquire() @@ -882,7 +882,7 @@ def __init__(self, sock, mode='rb', bufsize=-1): self._sock = sock - if not isinstance(sock, _closedsocket): + if isinstance(sock, _nonblocking_api_mixin): sock.reference_count += 1 self.mode = mode # Not actually used in this version if bufsize < 0: @@ -908,7 +908,7 @@ if self._sock: self.flush() finally: - if self._sock and not isinstance(self._sock, _closedsocket): + if self._sock and isinstance(self._sock, _nonblocking_api_mixin): self._sock.reference_count -= 1 if not self._sock.reference_count: self._sock.close() Modified: trunk/jython/Lib/test/test_socket.py =================================================================== --- trunk/jython/Lib/test/test_socket.py 2007-12-24 19:01:06 UTC (rev 3862) +++ trunk/jython/Lib/test/test_socket.py 2007-12-24 19:36:03 UTC (rev 3863) @@ -15,6 +15,7 @@ import Queue import sys from weakref import proxy +from StringIO import StringIO PORT = 50007 HOST = 'localhost' @@ -943,6 +944,26 @@ def _testClosedAttr(self): self.assert_(not self.cli_file.closed) +class PrivateFileObjectTestCase(unittest.TestCase): + + """Test usage of socket._fileobject with an arbitrary socket-like + object. + + E.g. urllib2 wraps an httplib.HTTPResponse object with _fileobject. 
+ """ + + def setUp(self): + self.socket_like = StringIO() + self.socket_like.recv = self.socket_like.read + self.socket_like.sendall = self.socket_like.write + + def testPrivateFileObject(self): + fileobject = socket._fileobject(self.socket_like, 'rb') + fileobject.write('hello jython') + fileobject.flush() + self.socket_like.seek(0) + self.assertEqual(fileobject.read(), 'hello jython') + class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): """Repeat the tests from FileObjectClassTestCase with bufsize==0. @@ -1199,6 +1220,7 @@ UDPFileObjectClassOpenCloseTests, FileAndDupOpenCloseTests, FileObjectClassTestCase, + PrivateFileObjectTestCase, UnbufferedFileObjectClassTestCase, LineBufferedFileObjectClassTestCase, SmallBufferedFileObjectClassTestCase This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <pj...@us...> - 2008-01-10 04:18:41
|
Revision: 4014 http://jython.svn.sourceforge.net/jython/?rev=4014&view=rev Author: pjenvey Date: 2008-01-09 20:18:33 -0800 (Wed, 09 Jan 2008) Log Message: ----------- o implement lstat by cleverly comparing file's absolute paths vs their canonical paths. lstat can't determine anything about links except the fact that they are links o utilize lstat for os.path.islink Modified Paths: -------------- trunk/jython/Lib/javaos.py trunk/jython/Lib/javapath.py Modified: trunk/jython/Lib/javaos.py =================================================================== --- trunk/jython/Lib/javaos.py 2008-01-10 01:33:14 UTC (rev 4013) +++ trunk/jython/Lib/javaos.py 2008-01-10 04:18:33 UTC (rev 4014) @@ -291,6 +291,34 @@ mode = mode | _stat.S_IWRITE return stat_result((mode, 0, 0, 0, 0, 0, size, mtime, mtime, 0)) +def lstat(path): + """lstat(path) -> stat result + + Like stat(path), but do not follow symbolic links. + """ + f = File(sys.getPath(path)) + abs_parent = f.getAbsoluteFile().getParentFile() + can_parent = abs_parent.getCanonicalFile() + + if can_parent.getAbsolutePath() == abs_parent.getAbsolutePath(): + # The parent directory's absolute path is canonical.. + if f.getAbsolutePath() != f.getCanonicalPath(): + # but the file's absolute and paths differ (a link) + return stat_result((_stat.S_IFLNK, 0, 0, 0, 0, 0, 0, 0, 0, 0)) + + # The parent directory's path is not canonical (one of the parent + # directories is a symlink). 
Build a new path with the parent's + # canonical path and compare the files + f = File(_path.join(can_parent.getAbsolutePath(), f.getName())) + if f.getAbsolutePath() != f.getCanonicalPath(): + return stat_result((_stat.S_IFLNK, 0, 0, 0, 0, 0, 0, 0, 0, 0)) + + # Not a link, only now can we determine if it exists (because + # File.exists() returns False for dead links) + if not f.exists(): + raise OSError(0, 'No such file or directory', path) + return stat(path) + def utime(path, times): """utime(path, (atime, mtime)) utime(path, None) Modified: trunk/jython/Lib/javapath.py =================================================================== --- trunk/jython/Lib/javapath.py 2008-01-10 01:33:14 UTC (rev 4013) +++ trunk/jython/Lib/javapath.py 2008-01-10 04:18:33 UTC (rev 4014) @@ -6,7 +6,6 @@ """ # Incompletely implemented: -# islink -- How? # ismount -- How? # normcase -- How? @@ -14,6 +13,7 @@ # sameopenfile -- Java doesn't have fstat nor file descriptors? # samestat -- How? +import stat import sys from java.io import File import java.io.IOException @@ -143,13 +143,13 @@ return prefix def islink(path): - """Test whether a path is a symbolic link. + """Test whether a path is a symbolic link""" + try: + st = os.lstat(path) + except (os.error, AttributeError): + return False + return stat.S_ISLNK(st.st_mode) - XXX This incorrectly always returns false under JDK. - - """ - return 0 - def samefile(path, path2): """Test whether two pathnames reference the same actual file""" path = _tostr(path, "samefile") This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <cg...@us...> - 2008-01-14 08:46:23
|
Revision: 4024 http://jython.svn.sourceforge.net/jython/?rev=4024&view=rev Author: cgroves Date: 2008-01-14 00:46:22 -0800 (Mon, 14 Jan 2008) Log Message: ----------- Don't attempt to select a channel for OP_WRITE if it doesn't allow it. Server sockets only allow OP_ACCEPT Modified Paths: -------------- trunk/jython/Lib/select.py trunk/jython/Lib/test/test_select_new.py Modified: trunk/jython/Lib/select.py =================================================================== --- trunk/jython/Lib/select.py 2008-01-13 09:20:13 UTC (rev 4023) +++ trunk/jython/Lib/select.py 2008-01-14 08:46:22 UTC (rev 4024) @@ -73,7 +73,8 @@ else: jmask = OP_READ if mask & POLLOUT: - jmask |= OP_WRITE + if channel.validOps() & OP_WRITE: + jmask |= OP_WRITE if channel.validOps() & OP_CONNECT: jmask |= OP_CONNECT selectionkey = channel.register(self.selector, jmask) Modified: trunk/jython/Lib/test/test_select_new.py =================================================================== --- trunk/jython/Lib/test/test_select_new.py 2008-01-13 09:20:13 UTC (rev 4023) +++ trunk/jython/Lib/test/test_select_new.py 2008-01-14 08:46:22 UTC (rev 4024) @@ -42,7 +42,7 @@ pass def select_acceptable(self): - return select.select([self.server_socket], [], [], SELECT_TIMEOUT)[0] + return select.select([self.server_socket], [self.server_socket], [], SELECT_TIMEOUT)[0] def verify_acceptable(self): assert self.select_acceptable(), "Server socket should be acceptable" This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <am...@us...> - 2008-01-14 19:52:15
|
Revision: 4028 http://jython.svn.sourceforge.net/jython/?rev=4028&view=rev Author: amak Date: 2008-01-14 11:52:12 -0800 (Mon, 14 Jan 2008) Log Message: ----------- Modified the implementation of UDP sockets so that they use the java.nio APIs as much as possible. This was necessary because it was causing some legal code sequences to hang (CF bug #1782548). However, there is still a necessary dichotomy in the use of java.nio vs. java.net, because the java.nio APIs do not support timeouts. Therefore, when a UDP socket is in timeout mode, the java.net DatagramSocket API is still used. Otherwise, the java.nio DatagramChannel API is used. Unit tests updated accordingly. Also replaced some tabs with spaces in test_socket.py Modified Paths: -------------- trunk/jython/Lib/socket.py trunk/jython/Lib/test/test_socket.py Modified: trunk/jython/Lib/socket.py =================================================================== --- trunk/jython/Lib/socket.py 2008-01-14 18:40:55 UTC (rev 4027) +++ trunk/jython/Lib/socket.py 2008-01-14 19:52:12 UTC (rev 4028) @@ -265,7 +265,7 @@ # In timeout mode now new_cli_sock = self.jsocket.accept() return _client_socket_impl(new_cli_sock) - + class _datagram_socket_impl(_nio_impl): def __init__(self, port=None, address=None, reuse_addr=0): @@ -286,12 +286,86 @@ def finish_connect(self): return self.jchannel.finishConnect() - def receive(self, packet): - self.jsocket.receive(packet) + def disconnect(self): + """ + Disconnect the datagram socket. 
+ cpython appears not to have this operation + """ + self.jchannel.disconnect() - def send(self, packet): + def _do_send_net(self, byte_array, socket_address, flags): + # Need two separate implementations because the java.nio APIs do not support timeouts + num_bytes = len(byte_array) + if socket_address: + packet = java.net.DatagramPacket(byte_array, num_bytes, socket_address) + else: + packet = java.net.DatagramPacket(byte_array, num_bytes) self.jsocket.send(packet) + return num_bytes + def _do_send_nio(self, byte_array, socket_address, flags): + byte_buf = java.nio.ByteBuffer.wrap(byte_array) + bytes_sent = self.jchannel.send(byte_buf, socket_address) + return bytes_sent + + def sendto(self, byte_array, address, flags): + host, port = _unpack_address_tuple(address) + socket_address = java.net.InetSocketAddress(host, port) + if self.mode == MODE_TIMEOUT: + return self._do_send_net(byte_array, socket_address, flags) + else: + return self._do_send_nio(byte_array, socket_address, flags) + + def send(self, byte_array, flags): + if self.mode == MODE_TIMEOUT: + return self._do_send_net(byte_array, None, flags) + else: + return self._do_send_nio(byte_array, None, flags) + + def _do_receive_net(self, return_source_address, num_bytes, flags): + byte_array = jarray.zeros(num_bytes, 'b') + packet = java.net.DatagramPacket(byte_array, num_bytes) + self.jsocket.receive(packet) + bytes_rcvd = packet.getLength() + if bytes_rcvd < num_bytes: + byte_array = byte_array[:bytes_rcvd] + return_data = byte_array.tostring() + if return_source_address: + host = None + if packet.getAddress(): + host = packet.getAddress().getHostName() + port = packet.getPort() + return return_data, (host, port) + else: + return return_data + + def _do_receive_nio(self, return_source_address, num_bytes, flags): + byte_array = jarray.zeros(num_bytes, 'b') + byte_buf = java.nio.ByteBuffer.wrap(byte_array) + source_address = self.jchannel.receive(byte_buf) + byte_buf.flip() ; bytes_read = byte_buf.remaining() 
+ if source_address is None and not self.jchannel.isBlocking(): + raise would_block_error() + if bytes_read < num_bytes: + byte_array = byte_array[:bytes_read] + return_data = byte_array.tostring() + if return_source_address: + return return_data, (source_address.getHostName(), source_address.getPort()) + else: + return return_data + + def recvfrom(self, num_bytes, flags): + if self.mode == MODE_TIMEOUT: + return self._do_receive_net(1, num_bytes, flags) + else: + return self._do_receive_nio(1, num_bytes, flags) + + def recv(self, num_bytes, flags): + if self.mode == MODE_TIMEOUT: + return self._do_receive_net(0, num_bytes, flags) + else: + return self._do_receive_nio(0, num_bytes, flags) + __all__ = [ 'AF_INET', 'SO_REUSEADDR', 'SOCK_DGRAM', 'SOCK_RAW', 'SOCK_RDM', 'SOCK_SEQPACKET', 'SOCK_STREAM', 'SOL_SOCKET', 'SocketType', 'error', 'herror', 'gaierror', 'timeout', @@ -685,49 +759,29 @@ flags, addr = 0, p1 else: flags, addr = 0, p2 - n = len(data) if not self.sock_impl: self.sock_impl = _datagram_socket_impl() - host, port = addr - bytes = java.lang.String(data).getBytes('iso-8859-1') - a = java.net.InetAddress.getByName(host) - packet = java.net.DatagramPacket(bytes, n, a, port) - self.sock_impl.send(packet) - return n + byte_array = java.lang.String(data).getBytes('iso-8859-1') + result = self.sock_impl.sendto(byte_array, addr, flags) + return result except java.lang.Exception, jlx: raise _map_exception(jlx) - def send(self, data): + def send(self, data, flags=None): if not self.addr: raise error(errno.ENOTCONN, "Socket is not connected") - return self.sendto(data, self.addr) + byte_array = java.lang.String(data).getBytes('iso-8859-1') + return self.sock_impl.send(byte_array, flags) - def recvfrom(self, n): + def recvfrom(self, num_bytes, flags=None): try: assert self.sock_impl - bytes = jarray.zeros(n, 'b') - packet = java.net.DatagramPacket(bytes, n) - self.sock_impl.receive(packet) - host = None - if packet.getAddress(): - host = 
packet.getAddress().getHostName() - port = packet.getPort() - m = packet.getLength() - if m < n: - bytes = bytes[:m] - return bytes.tostring(), (host, port) + return self.sock_impl.recvfrom(num_bytes, flags) except java.lang.Exception, jlx: raise _map_exception(jlx) - def recv(self, n): + def recv(self, num_bytes, flags=None): try: - assert self.sock_impl - bytes = jarray.zeros(n, 'b') - packet = java.net.DatagramPacket(bytes, n) - self.sock_impl.receive(packet) - m = packet.getLength() - if m < n: - bytes = bytes[:m] - return bytes.tostring() + return self.sock_impl.recv(num_bytes, flags) except java.lang.Exception, jlx: raise _map_exception(jlx) Modified: trunk/jython/Lib/test/test_socket.py =================================================================== --- trunk/jython/Lib/test/test_socket.py 2008-01-14 18:40:55 UTC (rev 4027) +++ trunk/jython/Lib/test/test_socket.py 2008-01-14 19:52:12 UTC (rev 4028) @@ -602,13 +602,24 @@ ThreadedUDPSocketTest.__init__(self, methodName=methodName) def testSendtoAndRecv(self): - # Testing sendto() and Recv() over UDP + # Testing sendto() and recv() over UDP msg = self.serv.recv(len(MSG)) self.assertEqual(msg, MSG) def _testSendtoAndRecv(self): self.cli.sendto(MSG, 0, (HOST, PORT)) + def testSendtoAndRecvTimeoutMode(self): + # Need to test again in timeout mode, which follows + # a different code path + self.serv.settimeout(10) + msg = self.serv.recv(len(MSG)) + self.assertEqual(msg, MSG) + + def _testSendtoAndRecvTimeoutMode(self): + self.cli.settimeout(10) + self.cli.sendto(MSG, 0, (HOST, PORT)) + def testRecvFrom(self): # Testing recvfrom() over UDP msg, addr = self.serv.recvfrom(len(MSG)) @@ -617,6 +628,17 @@ def _testRecvFrom(self): self.cli.sendto(MSG, 0, (HOST, PORT)) + def testRecvFromTimeoutMode(self): + # Need to test again in timeout mode, which follows + # a different code path + self.serv.settimeout(10) + msg, addr = self.serv.recvfrom(len(MSG)) + self.assertEqual(msg, MSG) + + def 
_testRecvFromTimeoutMode(self): + self.cli.settimeout(10) + self.cli.sendto(MSG, 0, (HOST, PORT)) + def testSendtoEightBitSafe(self): # This test is necessary because java only supports signed bytes msg = self.serv.recv(len(EIGHT_BIT_MSG)) @@ -625,6 +647,17 @@ def _testSendtoEightBitSafe(self): self.cli.sendto(EIGHT_BIT_MSG, 0, (HOST, PORT)) + def testSendtoEightBitSafeTimeoutMode(self): + # Need to test again in timeout mode, which follows + # a different code path + self.serv.settimeout(10) + msg = self.serv.recv(len(EIGHT_BIT_MSG)) + self.assertEqual(msg, EIGHT_BIT_MSG) + + def _testSendtoEightBitSafeTimeoutMode(self): + self.cli.settimeout(10) + self.cli.sendto(EIGHT_BIT_MSG, 0, (HOST, PORT)) + class BasicSocketPairTest(SocketPairTest): def __init__(self, methodName='runTest'): @@ -1169,39 +1202,39 @@ class TestAddressParameters: - def testBindNonTupleEndpointRaisesTypeError(self): - try: - self.socket.bind(HOST, PORT) - except TypeError: - pass - else: - self.fail("Illegal non-tuple bind address did not raise TypeError") + def testBindNonTupleEndpointRaisesTypeError(self): + try: + self.socket.bind(HOST, PORT) + except TypeError: + pass + else: + self.fail("Illegal non-tuple bind address did not raise TypeError") - def testConnectNonTupleEndpointRaisesTypeError(self): - try: - self.socket.connect(HOST, PORT) - except TypeError: - pass - else: - self.fail("Illegal non-tuple connect address did not raise TypeError") + def testConnectNonTupleEndpointRaisesTypeError(self): + try: + self.socket.connect(HOST, PORT) + except TypeError: + pass + else: + self.fail("Illegal non-tuple connect address did not raise TypeError") - def testConnectExNonTupleEndpointRaisesTypeError(self): - try: - self.socket.connect_ex(HOST, PORT) - except TypeError: - pass - else: - self.fail("Illegal non-tuple connect address did not raise TypeError") + def testConnectExNonTupleEndpointRaisesTypeError(self): + try: + self.socket.connect_ex(HOST, PORT) + except TypeError: + pass + else: + 
self.fail("Illegal non-tuple connect address did not raise TypeError") class TestTCPAddressParameters(unittest.TestCase, TestAddressParameters): - def setUp(self): - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + def setUp(self): + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) class TestUDPAddressParameters(unittest.TestCase, TestAddressParameters): - def setUp(self): - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + def setUp(self): + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) def test_main(): tests = [ This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <pj...@us...> - 2008-01-18 00:03:15
|
Revision: 4052 http://jython.svn.sourceforge.net/jython/?rev=4052&view=rev Author: pjenvey Date: 2008-01-17 16:03:05 -0800 (Thu, 17 Jan 2008) Log Message: ----------- from http://svn.python.org/projects/python/branches/release23-maint/Lib@60030 Added Paths: ----------- trunk/jython/Lib/inspect.py trunk/jython/Lib/test/test_inspect.py Added: trunk/jython/Lib/inspect.py =================================================================== --- trunk/jython/Lib/inspect.py (rev 0) +++ trunk/jython/Lib/inspect.py 2008-01-18 00:03:05 UTC (rev 4052) @@ -0,0 +1,809 @@ +# -*- coding: iso-8859-1 -*- +"""Get useful information from live Python objects. + +This module encapsulates the interface provided by the internal special +attributes (func_*, co_*, im_*, tb_*, etc.) in a friendlier fashion. +It also provides some help for examining source code and class layout. + +Here are some of the useful functions provided by this module: + + ismodule(), isclass(), ismethod(), isfunction(), istraceback(), + isframe(), iscode(), isbuiltin(), isroutine() - check object types + getmembers() - get members of an object that satisfy a given condition + + getfile(), getsourcefile(), getsource() - find an object's source code + getdoc(), getcomments() - get documentation on an object + getmodule() - determine the module that an object came from + getclasstree() - arrange classes so as to represent their hierarchy + + getargspec(), getargvalues() - get info about function arguments + formatargspec(), formatargvalues() - format an argument spec + getouterframes(), getinnerframes() - get info about frames + currentframe() - get the current stack frame + stack(), trace() - get info about frames on the stack or in a traceback +""" + +# This module is in the public domain. No warranties. 
+ +__author__ = 'Ka-Ping Yee <pi...@lf...>' +__date__ = '1 Jan 2001' + +import sys, os, types, string, re, dis, imp, tokenize, linecache + +# ----------------------------------------------------------- type-checking +def ismodule(object): + """Return true if the object is a module. + + Module objects provide these attributes: + __doc__ documentation string + __file__ filename (missing for built-in modules)""" + return isinstance(object, types.ModuleType) + +def isclass(object): + """Return true if the object is a class. + + Class objects provide these attributes: + __doc__ documentation string + __module__ name of module in which this class was defined""" + return isinstance(object, types.ClassType) or hasattr(object, '__bases__') + +def ismethod(object): + """Return true if the object is an instance method. + + Instance method objects provide these attributes: + __doc__ documentation string + __name__ name with which this method was defined + im_class class object in which this method belongs + im_func function object containing implementation of method + im_self instance to which this method is bound, or None""" + return isinstance(object, types.MethodType) + +def ismethoddescriptor(object): + """Return true if the object is a method descriptor. + + But not if ismethod() or isclass() or isfunction() are true. + + This is new in Python 2.2, and, for example, is true of int.__add__. + An object passing this test has a __get__ attribute but not a __set__ + attribute, but beyond that the set of attributes varies. __name__ is + usually sensible, and __doc__ often is. 
+ + Methods implemented via descriptors that also pass one of the other + tests return false from the ismethoddescriptor() test, simply because + the other tests promise more -- you can, e.g., count on having the + im_func attribute (etc) when an object passes ismethod().""" + return (hasattr(object, "__get__") + and not hasattr(object, "__set__") # else it's a data descriptor + and not ismethod(object) # mutual exclusion + and not isfunction(object) + and not isclass(object)) + +def isdatadescriptor(object): + """Return true if the object is a data descriptor. + + Data descriptors have both a __get__ and a __set__ attribute. Examples are + properties (defined in Python) and getsets and members (defined in C). + Typically, data descriptors will also have __name__ and __doc__ attributes + (properties, getsets, and members have both of these attributes), but this + is not guaranteed.""" + return (hasattr(object, "__set__") and hasattr(object, "__get__")) + +def isfunction(object): + """Return true if the object is a user-defined function. + + Function objects provide these attributes: + __doc__ documentation string + __name__ name with which this function was defined + func_code code object containing compiled function bytecode + func_defaults tuple of any default values for arguments + func_doc (same as __doc__) + func_globals global namespace in which this function was defined + func_name (same as __name__)""" + return isinstance(object, types.FunctionType) + +def istraceback(object): + """Return true if the object is a traceback. + + Traceback objects provide these attributes: + tb_frame frame object at this level + tb_lasti index of last attempted instruction in bytecode + tb_lineno current line number in Python source code + tb_next next inner traceback object (called by this level)""" + return isinstance(object, types.TracebackType) + +def isframe(object): + """Return true if the object is a frame object. 
+ + Frame objects provide these attributes: + f_back next outer frame object (this frame's caller) + f_builtins built-in namespace seen by this frame + f_code code object being executed in this frame + f_exc_traceback traceback if raised in this frame, or None + f_exc_type exception type if raised in this frame, or None + f_exc_value exception value if raised in this frame, or None + f_globals global namespace seen by this frame + f_lasti index of last attempted instruction in bytecode + f_lineno current line number in Python source code + f_locals local namespace seen by this frame + f_restricted 0 or 1 if frame is in restricted execution mode + f_trace tracing function for this frame, or None""" + return isinstance(object, types.FrameType) + +def iscode(object): + """Return true if the object is a code object. + + Code objects provide these attributes: + co_argcount number of arguments (not including * or ** args) + co_code string of raw compiled bytecode + co_consts tuple of constants used in the bytecode + co_filename name of file in which this code object was created + co_firstlineno number of first line in Python source code + co_flags bitmap: 1=optimized | 2=newlocals | 4=*arg | 8=**arg + co_lnotab encoded mapping of line numbers to bytecode indices + co_name name with which this code object was defined + co_names tuple of names of local variables + co_nlocals number of local variables + co_stacksize virtual machine stack space required + co_varnames tuple of names of arguments and local variables""" + return isinstance(object, types.CodeType) + +def isbuiltin(object): + """Return true if the object is a built-in function or method. 
+ + Built-in functions and methods provide these attributes: + __doc__ documentation string + __name__ original name of this function or method + __self__ instance to which a method is bound, or None""" + return isinstance(object, types.BuiltinFunctionType) + +def isroutine(object): + """Return true if the object is any kind of function or method.""" + return (isbuiltin(object) + or isfunction(object) + or ismethod(object) + or ismethoddescriptor(object)) + +def getmembers(object, predicate=None): + """Return all members of an object as (name, value) pairs sorted by name. + Optionally, only return members that satisfy a given predicate.""" + results = [] + for key in dir(object): + value = getattr(object, key) + if not predicate or predicate(value): + results.append((key, value)) + results.sort() + return results + +def classify_class_attrs(cls): + """Return list of attribute-descriptor tuples. + + For each name in dir(cls), the return list contains a 4-tuple + with these elements: + + 0. The name (a string). + + 1. The kind of attribute this is, one of these strings: + 'class method' created via classmethod() + 'static method' created via staticmethod() + 'property' created via property() + 'method' any other flavor of method + 'data' not a method + + 2. The class which defined this attribute (a class). + + 3. The object as obtained directly from the defining class's + __dict__, not via getattr. This is especially important for + data attributes: C.data is just a data object, but + C.__dict__['data'] may be a data descriptor with additional + info, like a __doc__ string. + """ + + mro = getmro(cls) + names = dir(cls) + result = [] + for name in names: + # Get the object associated with the name. + # Getting an obj from the __dict__ sometimes reveals more than + # using getattr. Static and class methods are dramatic examples. + if name in cls.__dict__: + obj = cls.__dict__[name] + else: + obj = getattr(cls, name) + + # Figure out where it was defined. 
+ homecls = getattr(obj, "__objclass__", None) + if homecls is None: + # search the dicts. + for base in mro: + if name in base.__dict__: + homecls = base + break + + # Get the object again, in order to get it from the defining + # __dict__ instead of via getattr (if possible). + if homecls is not None and name in homecls.__dict__: + obj = homecls.__dict__[name] + + # Also get the object via getattr. + obj_via_getattr = getattr(cls, name) + + # Classify the object. + if isinstance(obj, staticmethod): + kind = "static method" + elif isinstance(obj, classmethod): + kind = "class method" + elif isinstance(obj, property): + kind = "property" + elif (ismethod(obj_via_getattr) or + ismethoddescriptor(obj_via_getattr)): + kind = "method" + else: + kind = "data" + + result.append((name, kind, homecls, obj)) + + return result + +# ----------------------------------------------------------- class helpers +def _searchbases(cls, accum): + # Simulate the "classic class" search order. + if cls in accum: + return + accum.append(cls) + for base in cls.__bases__: + _searchbases(base, accum) + +def getmro(cls): + "Return tuple of base classes (including cls) in method resolution order." + if hasattr(cls, "__mro__"): + return cls.__mro__ + else: + result = [] + _searchbases(cls, result) + return tuple(result) + +# -------------------------------------------------- source code extraction +def indentsize(line): + """Return the indent size, in spaces, at the start of a line of text.""" + expline = string.expandtabs(line) + return len(expline) - len(string.lstrip(expline)) + +def getdoc(object): + """Get the documentation string for an object. + + All tabs are expanded to spaces. 
To clean up docstrings that are + indented to line up with blocks of code, any whitespace than can be + uniformly removed from the second line onwards is removed.""" + try: + doc = object.__doc__ + except AttributeError: + return None + if not isinstance(doc, types.StringTypes): + return None + try: + lines = string.split(string.expandtabs(doc), '\n') + except UnicodeError: + return None + else: + # Find minimum indentation of any non-blank lines after first line. + margin = sys.maxint + for line in lines[1:]: + content = len(string.lstrip(line)) + if content: + indent = len(line) - content + margin = min(margin, indent) + # Remove indentation. + if lines: + lines[0] = lines[0].lstrip() + if margin < sys.maxint: + for i in range(1, len(lines)): lines[i] = lines[i][margin:] + # Remove any trailing or leading blank lines. + while lines and not lines[-1]: + lines.pop() + while lines and not lines[0]: + lines.pop(0) + return string.join(lines, '\n') + +def getfile(object): + """Work out which source or compiled file an object was defined in.""" + if ismodule(object): + if hasattr(object, '__file__'): + return object.__file__ + raise TypeError('arg is a built-in module') + if isclass(object): + object = sys.modules.get(object.__module__) + if hasattr(object, '__file__'): + return object.__file__ + raise TypeError('arg is a built-in class') + if ismethod(object): + object = object.im_func + if isfunction(object): + object = object.func_code + if istraceback(object): + object = object.tb_frame + if isframe(object): + object = object.f_code + if iscode(object): + return object.co_filename + raise TypeError('arg is not a module, class, method, ' + 'function, traceback, frame, or code object') + +def getmoduleinfo(path): + """Get the module name, suffix, mode, and module type for a given file.""" + filename = os.path.basename(path) + suffixes = map(lambda (suffix, mode, mtype): + (-len(suffix), suffix, mode, mtype), imp.get_suffixes()) + suffixes.sort() # try longest 
suffixes first, in case they overlap + for neglen, suffix, mode, mtype in suffixes: + if filename[neglen:] == suffix: + return filename[:neglen], suffix, mode, mtype + +def getmodulename(path): + """Return the module name for a given file, or None.""" + info = getmoduleinfo(path) + if info: return info[0] + +def getsourcefile(object): + """Return the Python source file an object was defined in, if it exists.""" + filename = getfile(object) + if string.lower(filename[-4:]) in ['.pyc', '.pyo']: + filename = filename[:-4] + '.py' + for suffix, mode, kind in imp.get_suffixes(): + if 'b' in mode and string.lower(filename[-len(suffix):]) == suffix: + # Looks like a binary file. We want to only return a text file. + return None + if os.path.exists(filename): + return filename + +def getabsfile(object): + """Return an absolute path to the source or compiled file for an object. + + The idea is for each object to have a unique origin, so this routine + normalizes the result as much as possible.""" + return os.path.normcase( + os.path.abspath(getsourcefile(object) or getfile(object))) + +modulesbyfile = {} + +def getmodule(object): + """Return the module an object was defined in, or None if not found.""" + if ismodule(object): + return object + if isclass(object): + return sys.modules.get(object.__module__) + try: + file = getabsfile(object) + except TypeError: + return None + if file in modulesbyfile: + return sys.modules.get(modulesbyfile[file]) + for module in sys.modules.values(): + if hasattr(module, '__file__'): + modulesbyfile[getabsfile(module)] = module.__name__ + if file in modulesbyfile: + return sys.modules.get(modulesbyfile[file]) + main = sys.modules['__main__'] + if not hasattr(object, '__name__'): + return None + if hasattr(main, object.__name__): + mainobject = getattr(main, object.__name__) + if mainobject is object: + return main + builtin = sys.modules['__builtin__'] + if hasattr(builtin, object.__name__): + builtinobject = getattr(builtin, 
object.__name__) + if builtinobject is object: + return builtin + +def findsource(object): + """Return the entire source file and starting line number for an object. + + The argument may be a module, class, method, function, traceback, frame, + or code object. The source code is returned as a list of all the lines + in the file and the line number indexes a line in that list. An IOError + is raised if the source code cannot be retrieved.""" + file = getsourcefile(object) or getfile(object) + lines = linecache.getlines(file) + if not lines: + raise IOError('could not get source code') + + if ismodule(object): + return lines, 0 + + if isclass(object): + name = object.__name__ + pat = re.compile(r'^\s*class\s*' + name + r'\b') + for i in range(len(lines)): + if pat.match(lines[i]): return lines, i + else: + raise IOError('could not find class definition') + + if ismethod(object): + object = object.im_func + if isfunction(object): + object = object.func_code + if istraceback(object): + object = object.tb_frame + if isframe(object): + object = object.f_code + if iscode(object): + if not hasattr(object, 'co_firstlineno'): + raise IOError('could not find function definition') + lnum = object.co_firstlineno - 1 + pat = re.compile(r'^(\s*def\s)|(.*\slambda(:|\s))') + while lnum > 0: + if pat.match(lines[lnum]): break + lnum = lnum - 1 + return lines, lnum + raise IOError('could not find code object') + +def getcomments(object): + """Get lines of comments immediately preceding an object's source code. + + Returns None when source can't be found. + """ + try: + lines, lnum = findsource(object) + except (IOError, TypeError): + return None + + if ismodule(object): + # Look for a comment block at the top of the file. 
+ start = 0 + if lines and lines[0][:2] == '#!': start = 1 + while start < len(lines) and string.strip(lines[start]) in ['', '#']: + start = start + 1 + if start < len(lines) and lines[start][:1] == '#': + comments = [] + end = start + while end < len(lines) and lines[end][:1] == '#': + comments.append(string.expandtabs(lines[end])) + end = end + 1 + return string.join(comments, '') + + # Look for a preceding block of comments at the same indentation. + elif lnum > 0: + indent = indentsize(lines[lnum]) + end = lnum - 1 + if end >= 0 and string.lstrip(lines[end])[:1] == '#' and \ + indentsize(lines[end]) == indent: + comments = [string.lstrip(string.expandtabs(lines[end]))] + if end > 0: + end = end - 1 + comment = string.lstrip(string.expandtabs(lines[end])) + while comment[:1] == '#' and indentsize(lines[end]) == indent: + comments[:0] = [comment] + end = end - 1 + if end < 0: break + comment = string.lstrip(string.expandtabs(lines[end])) + while comments and string.strip(comments[0]) == '#': + comments[:1] = [] + while comments and string.strip(comments[-1]) == '#': + comments[-1:] = [] + return string.join(comments, '') + +class ListReader: + """Provide a readline() method to return lines from a list of strings.""" + def __init__(self, lines): + self.lines = lines + self.index = 0 + + def readline(self): + i = self.index + if i < len(self.lines): + self.index = i + 1 + return self.lines[i] + else: return '' + +class EndOfBlock(Exception): pass + +class BlockFinder: + """Provide a tokeneater() method to detect the end of a code block.""" + def __init__(self): + self.indent = 0 + self.started = 0 + self.last = 0 + + def tokeneater(self, type, token, (srow, scol), (erow, ecol), line): + if not self.started: + if type == tokenize.NAME: self.started = 1 + elif type == tokenize.NEWLINE: + self.last = srow + elif type == tokenize.INDENT: + self.indent = self.indent + 1 + elif type == tokenize.DEDENT: + self.indent = self.indent - 1 + if self.indent == 0: + raise 
EndOfBlock, self.last + elif type == tokenize.NAME and scol == 0: + raise EndOfBlock, self.last + +def getblock(lines): + """Extract the block of code at the top of the given list of lines.""" + try: + tokenize.tokenize(ListReader(lines).readline, BlockFinder().tokeneater) + except EndOfBlock, eob: + return lines[:eob.args[0]] + # Fooling the indent/dedent logic implies a one-line definition + return lines[:1] + +def getsourcelines(object): + """Return a list of source lines and starting line number for an object. + + The argument may be a module, class, method, function, traceback, frame, + or code object. The source code is returned as a list of the lines + corresponding to the object and the line number indicates where in the + original source file the first line of code was found. An IOError is + raised if the source code cannot be retrieved.""" + lines, lnum = findsource(object) + + if ismodule(object): return lines, 0 + else: return getblock(lines[lnum:]), lnum + 1 + +def getsource(object): + """Return the text of the source code for an object. + + The argument may be a module, class, method, function, traceback, frame, + or code object. The source code is returned as a single string. An + IOError is raised if the source code cannot be retrieved.""" + lines, lnum = getsourcelines(object) + return string.join(lines, '') + +# --------------------------------------------------- class tree extraction +def walktree(classes, children, parent): + """Recursive helper function for getclasstree().""" + results = [] + classes.sort(lambda a, b: cmp(a.__name__, b.__name__)) + for c in classes: + results.append((c, c.__bases__)) + if c in children: + results.append(walktree(children[c], children, c)) + return results + +def getclasstree(classes, unique=0): + """Arrange the given list of classes into a hierarchy of nested lists. + + Where a nested list appears, it contains classes derived from the class + whose entry immediately precedes the list. 
Each entry is a 2-tuple + containing a class and a tuple of its base classes. If the 'unique' + argument is true, exactly one entry appears in the returned structure + for each class in the given list. Otherwise, classes using multiple + inheritance and their descendants will appear multiple times.""" + children = {} + roots = [] + for c in classes: + if c.__bases__: + for parent in c.__bases__: + if not parent in children: + children[parent] = [] + children[parent].append(c) + if unique and parent in classes: break + elif c not in roots: + roots.append(c) + for parent in children: + if parent not in classes: + roots.append(parent) + return walktree(roots, children, None) + +# ------------------------------------------------ argument list extraction +# These constants are from Python's compile.h. +CO_OPTIMIZED, CO_NEWLOCALS, CO_VARARGS, CO_VARKEYWORDS = 1, 2, 4, 8 + +def getargs(co): + """Get information about the arguments accepted by a code object. + + Three things are returned: (args, varargs, varkw), where 'args' is + a list of argument names (possibly containing nested lists), and + 'varargs' and 'varkw' are the names of the * and ** arguments or None.""" + + if not iscode(co): + raise TypeError('arg is not a code object') + + code = co.co_code + nargs = co.co_argcount + names = co.co_varnames + args = list(names[:nargs]) + step = 0 + + # The following acrobatics are for anonymous (tuple) arguments. 
+ for i in range(nargs): + if args[i][:1] in ['', '.']: + stack, remain, count = [], [], [] + while step < len(code): + op = ord(code[step]) + step = step + 1 + if op >= dis.HAVE_ARGUMENT: + opname = dis.opname[op] + value = ord(code[step]) + ord(code[step+1])*256 + step = step + 2 + if opname in ['UNPACK_TUPLE', 'UNPACK_SEQUENCE']: + remain.append(value) + count.append(value) + elif opname == 'STORE_FAST': + stack.append(names[value]) + + # Special case for sublists of length 1: def foo((bar)) + # doesn't generate the UNPACK_TUPLE bytecode, so if + # `remain` is empty here, we have such a sublist. + if not remain: + stack[0] = [stack[0]] + break + else: + remain[-1] = remain[-1] - 1 + while remain[-1] == 0: + remain.pop() + size = count.pop() + stack[-size:] = [stack[-size:]] + if not remain: break + remain[-1] = remain[-1] - 1 + if not remain: break + args[i] = stack[0] + + varargs = None + if co.co_flags & CO_VARARGS: + varargs = co.co_varnames[nargs] + nargs = nargs + 1 + varkw = None + if co.co_flags & CO_VARKEYWORDS: + varkw = co.co_varnames[nargs] + return args, varargs, varkw + +def getargspec(func): + """Get the names and default values of a function's arguments. + + A tuple of four things is returned: (args, varargs, varkw, defaults). + 'args' is a list of the argument names (it may contain nested lists). + 'varargs' and 'varkw' are the names of the * and ** arguments or None. + 'defaults' is an n-tuple of the default values of the last n arguments. + """ + + if ismethod(func): + func = func.im_func + if not isfunction(func): + raise TypeError('arg is not a Python function') + args, varargs, varkw = getargs(func.func_code) + return args, varargs, varkw, func.func_defaults + +def getargvalues(frame): + """Get information about arguments passed into a particular frame. + + A tuple of four things is returned: (args, varargs, varkw, locals). + 'args' is a list of the argument names (it may contain nested lists). 
+ 'varargs' and 'varkw' are the names of the * and ** arguments or None. + 'locals' is the locals dictionary of the given frame.""" + args, varargs, varkw = getargs(frame.f_code) + return args, varargs, varkw, frame.f_locals + +def joinseq(seq): + if len(seq) == 1: + return '(' + seq[0] + ',)' + else: + return '(' + string.join(seq, ', ') + ')' + +def strseq(object, convert, join=joinseq): + """Recursively walk a sequence, stringifying each element.""" + if type(object) in [types.ListType, types.TupleType]: + return join(map(lambda o, c=convert, j=join: strseq(o, c, j), object)) + else: + return convert(object) + +def formatargspec(args, varargs=None, varkw=None, defaults=None, + formatarg=str, + formatvarargs=lambda name: '*' + name, + formatvarkw=lambda name: '**' + name, + formatvalue=lambda value: '=' + repr(value), + join=joinseq): + """Format an argument spec from the 4 values returned by getargspec. + + The first four arguments are (args, varargs, varkw, defaults). The + other four arguments are the corresponding optional formatting functions + that are called to turn names and values into strings. The ninth + argument is an optional function to format the sequence of arguments.""" + specs = [] + if defaults: + firstdefault = len(args) - len(defaults) + for i in range(len(args)): + spec = strseq(args[i], formatarg, join) + if defaults and i >= firstdefault: + spec = spec + formatvalue(defaults[i - firstdefault]) + specs.append(spec) + if varargs is not None: + specs.append(formatvarargs(varargs)) + if varkw is not None: + specs.append(formatvarkw(varkw)) + return '(' + string.join(specs, ', ') + ')' + +def formatargvalues(args, varargs, varkw, locals, + formatarg=str, + formatvarargs=lambda name: '*' + name, + formatvarkw=lambda name: '**' + name, + formatvalue=lambda value: '=' + repr(value), + join=joinseq): + """Format an argument spec from the 4 values returned by getargvalues. + + The first four arguments are (args, varargs, varkw, locals). 
The + next four arguments are the corresponding optional formatting functions + that are called to turn names and values into strings. The ninth + argument is an optional function to format the sequence of arguments.""" + def convert(name, locals=locals, + formatarg=formatarg, formatvalue=formatvalue): + return formatarg(name) + formatvalue(locals[name]) + specs = [] + for i in range(len(args)): + specs.append(strseq(args[i], convert, join)) + if varargs: + specs.append(formatvarargs(varargs) + formatvalue(locals[varargs])) + if varkw: + specs.append(formatvarkw(varkw) + formatvalue(locals[varkw])) + return '(' + string.join(specs, ', ') + ')' + +# -------------------------------------------------- stack frame extraction +def getframeinfo(frame, context=1): + """Get information about a frame or traceback object. + + A tuple of five things is returned: the filename, the line number of + the current line, the function name, a list of lines of context from + the source code, and the index of the current line within that list. 
+ The optional second argument specifies the number of lines of context + to return, which are centered around the current line.""" + if istraceback(frame): + lineno = frame.tb_lineno + frame = frame.tb_frame + else: + lineno = frame.f_lineno + if not isframe(frame): + raise TypeError('arg is not a frame or traceback object') + + filename = getsourcefile(frame) or getfile(frame) + if context > 0: + start = lineno - 1 - context//2 + try: + lines, lnum = findsource(frame) + except IOError: + lines = index = None + else: + start = max(start, 1) + start = min(start, len(lines) - context) + lines = lines[start:start+context] + index = lineno - 1 - start + else: + lines = index = None + + return (filename, lineno, frame.f_code.co_name, lines, index) + +def getlineno(frame): + """Get the line number from a frame object, allowing for optimization.""" + # FrameType.f_lineno is now a descriptor that grovels co_lnotab + return frame.f_lineno + +def getouterframes(frame, context=1): + """Get a list of records for a frame and all higher (calling) frames. + + Each record contains a frame object, filename, line number, function + name, a list of lines of context, and index within the context.""" + framelist = [] + while frame: + framelist.append((frame,) + getframeinfo(frame, context)) + frame = frame.f_back + return framelist + +def getinnerframes(tb, context=1): + """Get a list of records for a traceback's frame and all lower frames. 
+ + Each record contains a frame object, filename, line number, function + name, a list of lines of context, and index within the context.""" + framelist = [] + while tb: + framelist.append((tb.tb_frame,) + getframeinfo(tb, context)) + tb = tb.tb_next + return framelist + +currentframe = sys._getframe + +def stack(context=1): + """Return a list of records for the stack above the caller's frame.""" + return getouterframes(sys._getframe(1), context) + +def trace(context=1): + """Return a list of records for the stack below the current exception.""" + return getinnerframes(sys.exc_info()[2], context) Added: trunk/jython/Lib/test/test_inspect.py =================================================================== --- trunk/jython/Lib/test/test_inspect.py (rev 0) +++ trunk/jython/Lib/test/test_inspect.py 2008-01-18 00:03:05 UTC (rev 4052) @@ -0,0 +1,384 @@ +source = '''# line 1 +'A module docstring.' + +import sys, inspect +# line 5 + +# line 7 +def spam(a, b, c, d=3, (e, (f,))=(4, (5,)), *g, **h): + eggs(b + d, c + f) + +# line 11 +def eggs(x, y): + "A docstring." 
+ global fr, st + fr = inspect.currentframe() + st = inspect.stack() + p = x + q = y / 0 + +# line 20 +class StupidGit: + """A longer, + + indented + + docstring.""" +# line 27 + + def abuse(self, a, b, c): + """Another + +\tdocstring + + containing + +\ttabs +\t + """ + self.argue(a, b, c) +# line 40 + def argue(self, a, b, c): + try: + spam(a, b, c) + except: + self.ex = sys.exc_info() + self.tr = inspect.trace() + +# line 48 +class MalodorousPervert(StupidGit): + pass + +class ParrotDroppings: + pass + +class FesteringGob(MalodorousPervert, ParrotDroppings): + pass +''' + +# Functions tested in this suite: +# ismodule, isclass, ismethod, isfunction, istraceback, isframe, iscode, +# isbuiltin, isroutine, getmembers, getdoc, getfile, getmodule, +# getsourcefile, getcomments, getsource, getclasstree, getargspec, +# getargvalues, formatargspec, formatargvalues, currentframe, stack, trace +# isdatadescriptor + +from test.test_support import TestFailed, TESTFN +import sys, imp, os, string + +def test(assertion, message, *args): + if not assertion: + raise TestFailed, message % args + +import inspect + +file = open(TESTFN, 'w') +file.write(source) +file.close() + +# Note that load_source creates file TESTFN+'c' or TESTFN+'o'. 
+mod = imp.load_source('testmod', TESTFN) +files_to_clean_up = [TESTFN, TESTFN + 'c', TESTFN + 'o'] + +def istest(func, exp): + obj = eval(exp) + test(func(obj), '%s(%s)' % (func.__name__, exp)) + for other in [inspect.isbuiltin, inspect.isclass, inspect.iscode, + inspect.isframe, inspect.isfunction, inspect.ismethod, + inspect.ismodule, inspect.istraceback]: + if other is not func: + test(not other(obj), 'not %s(%s)' % (other.__name__, exp)) + +git = mod.StupidGit() +try: + 1/0 +except: + tb = sys.exc_traceback + +istest(inspect.isbuiltin, 'sys.exit') +istest(inspect.isbuiltin, '[].append') +istest(inspect.isclass, 'mod.StupidGit') +istest(inspect.iscode, 'mod.spam.func_code') +istest(inspect.isframe, 'tb.tb_frame') +istest(inspect.isfunction, 'mod.spam') +istest(inspect.ismethod, 'mod.StupidGit.abuse') +istest(inspect.ismethod, 'git.argue') +istest(inspect.ismodule, 'mod') +istest(inspect.istraceback, 'tb') +import __builtin__ +istest(inspect.isdatadescriptor, '__builtin__.file.closed') +istest(inspect.isdatadescriptor, '__builtin__.file.softspace') +test(inspect.isroutine(mod.spam), 'isroutine(mod.spam)') +test(inspect.isroutine([].count), 'isroutine([].count)') + +classes = inspect.getmembers(mod, inspect.isclass) +test(classes == + [('FesteringGob', mod.FesteringGob), + ('MalodorousPervert', mod.MalodorousPervert), + ('ParrotDroppings', mod.ParrotDroppings), + ('StupidGit', mod.StupidGit)], 'class list') +tree = inspect.getclasstree(map(lambda x: x[1], classes), 1) +test(tree == + [(mod.ParrotDroppings, ()), + (mod.StupidGit, ()), + [(mod.MalodorousPervert, (mod.StupidGit,)), + [(mod.FesteringGob, (mod.MalodorousPervert, mod.ParrotDroppings)) + ] + ] + ], 'class tree') + +functions = inspect.getmembers(mod, inspect.isfunction) +test(functions == [('eggs', mod.eggs), ('spam', mod.spam)], 'function list') + +test(inspect.getdoc(mod) == 'A module docstring.', 'getdoc(mod)') +test(inspect.getcomments(mod) == '# line 1\n', 'getcomments(mod)') 
+test(inspect.getmodule(mod.StupidGit) == mod, 'getmodule(mod.StupidGit)') +test(inspect.getfile(mod.StupidGit) == TESTFN, 'getfile(mod.StupidGit)') +test(inspect.getsourcefile(mod.spam) == TESTFN, 'getsourcefile(mod.spam)') +test(inspect.getsourcefile(git.abuse) == TESTFN, 'getsourcefile(git.abuse)') + +def sourcerange(top, bottom): + lines = string.split(source, '\n') + return string.join(lines[top-1:bottom], '\n') + '\n' + +test(inspect.getsource(git.abuse) == sourcerange(29, 39), + 'getsource(git.abuse)') +test(inspect.getsource(mod.StupidGit) == sourcerange(21, 46), + 'getsource(mod.StupidGit)') +test(inspect.getdoc(mod.StupidGit) == + 'A longer,\n\nindented\n\ndocstring.', 'getdoc(mod.StupidGit)') +test(inspect.getdoc(git.abuse) == + 'Another\n\ndocstring\n\ncontaining\n\ntabs', 'getdoc(git.abuse)') +test(inspect.getcomments(mod.StupidGit) == '# line 20\n', + 'getcomments(mod.StupidGit)') + +git.abuse(7, 8, 9) + +istest(inspect.istraceback, 'git.ex[2]') +istest(inspect.isframe, 'mod.fr') + +test(len(git.tr) == 3, 'trace() length') +test(git.tr[0][1:] == (TESTFN, 43, 'argue', + [' spam(a, b, c)\n'], 0), + 'trace() row 2') +test(git.tr[1][1:] == (TESTFN, 9, 'spam', [' eggs(b + d, c + f)\n'], 0), + 'trace() row 2') +test(git.tr[2][1:] == (TESTFN, 18, 'eggs', [' q = y / 0\n'], 0), + 'trace() row 3') + +test(len(mod.st) >= 5, 'stack() length') +test(mod.st[0][1:] == + (TESTFN, 16, 'eggs', [' st = inspect.stack()\n'], 0), + 'stack() row 1') +test(mod.st[1][1:] == + (TESTFN, 9, 'spam', [' eggs(b + d, c + f)\n'], 0), + 'stack() row 2') +test(mod.st[2][1:] == + (TESTFN, 43, 'argue', [' spam(a, b, c)\n'], 0), + 'stack() row 3') +test(mod.st[3][1:] == + (TESTFN, 39, 'abuse', [' self.argue(a, b, c)\n'], 0), + 'stack() row 4') + +args, varargs, varkw, locals = inspect.getargvalues(mod.fr) +test(args == ['x', 'y'], 'mod.fr args') +test(varargs == None, 'mod.fr varargs') +test(varkw == None, 'mod.fr varkw') +test(locals == {'x': 11, 'p': 11, 'y': 14}, 'mod.fr locals') 
+test(inspect.formatargvalues(args, varargs, varkw, locals) == + '(x=11, y=14)', 'mod.fr formatted argvalues') + +args, varargs, varkw, locals = inspect.getargvalues(mod.fr.f_back) +test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.fr.f_back args') +test(varargs == 'g', 'mod.fr.f_back varargs') +test(varkw == 'h', 'mod.fr.f_back varkw') +test(inspect.formatargvalues(args, varargs, varkw, locals) == + '(a=7, b=8, c=9, d=3, (e=4, (f=5,)), *g=(), **h={})', + 'mod.fr.f_back formatted argvalues') + +for fname in files_to_clean_up: + try: + os.unlink(fname) + except: + pass + +# Test classic-class method resolution order. +class A: pass +class B(A): pass +class C(A): pass +class D(B, C): pass + +expected = (D, B, A, C) +got = inspect.getmro(D) +test(expected == got, "expected %r mro, got %r", expected, got) + +# The same w/ new-class MRO. +class A(object): pass +class B(A): pass +class C(A): pass +class D(B, C): pass + +expected = (D, B, C, A, object) +got = inspect.getmro(D) +test(expected == got, "expected %r mro, got %r", expected, got) + +# Test classify_class_attrs. 
+def attrs_wo_objs(cls): + return [t[:3] for t in inspect.classify_class_attrs(cls)] + +class A: + def s(): pass + s = staticmethod(s) + + def c(cls): pass + c = classmethod(c) + + def getp(self): pass + p = property(getp) + + def m(self): pass + + def m1(self): pass + + datablob = '1' + +attrs = attrs_wo_objs(A) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'class method', A) in attrs, 'missing class method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', A) in attrs, 'missing plain method') +test(('m1', 'method', A) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') + +class B(A): + def m(self): pass + +attrs = attrs_wo_objs(B) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'class method', A) in attrs, 'missing class method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', B) in attrs, 'missing plain method') +test(('m1', 'method', A) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') + + +class C(A): + def m(self): pass + def c(self): pass + +attrs = attrs_wo_objs(C) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'method', C) in attrs, 'missing plain method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', C) in attrs, 'missing plain method') +test(('m1', 'method', A) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') + +class D(B, C): + def m1(self): pass + +attrs = attrs_wo_objs(D) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'class method', A) in attrs, 'missing class method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', B) in attrs, 'missing plain method') +test(('m1', 'method', D) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') + +# Repeat all 
that, but w/ new-style classes. + +class A(object): + + def s(): pass + s = staticmethod(s) + + def c(cls): pass + c = classmethod(c) + + def getp(self): pass + p = property(getp) + + def m(self): pass + + def m1(self): pass + + datablob = '1' + +attrs = attrs_wo_objs(A) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'class method', A) in attrs, 'missing class method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', A) in attrs, 'missing plain method') +test(('m1', 'method', A) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') + +class B(A): + + def m(self): pass + +attrs = attrs_wo_objs(B) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'class method', A) in attrs, 'missing class method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', B) in attrs, 'missing plain method') +test(('m1', 'method', A) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') + + +class C(A): + + def m(self): pass + def c(self): pass + +attrs = attrs_wo_objs(C) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'method', C) in attrs, 'missing plain method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', C) in attrs, 'missing plain method') +test(('m1', 'method', A) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') + +class D(B, C): + + def m1(self): pass + +attrs = attrs_wo_objs(D) +test(('s', 'static method', A) in attrs, 'missing static method') +test(('c', 'method', C) in attrs, 'missing plain method') +test(('p', 'property', A) in attrs, 'missing property') +test(('m', 'method', B) in attrs, 'missing plain method') +test(('m1', 'method', D) in attrs, 'missing plain method') +test(('datablob', 'data', A) in attrs, 'missing data') + +args, varargs, varkw, defaults = 
inspect.getargspec(mod.eggs) +test(args == ['x', 'y'], 'mod.eggs args') +test(varargs == None, 'mod.eggs varargs') +test(varkw == None, 'mod.eggs varkw') +test(defaults == None, 'mod.eggs defaults') +test(inspect.formatargspec(args, varargs, varkw, defaults) == + '(x, y)', 'mod.eggs formatted argspec') +args, varargs, varkw, defaults = inspect.getargspec(mod.spam) +test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.spam args') +test(varargs == 'g', 'mod.spam varargs') +test(varkw == 'h', 'mod.spam varkw') +test(defaults == (3, (4, (5,))), 'mod.spam defaults') +test(inspect.formatargspec(args, varargs, varkw, defaults) == + '(a, b, c, d=3, (e, (f,))=(4, (5,)), *g, **h)', + 'mod.spam formatted argspec') +args, varargs, varkw, defaults = inspect.getargspec(A.m) +test(args == ['self'], 'A.m args') +test(varargs is None, 'A.m varargs') +test(varkw is None, 'A.m varkw') +test(defaults is None, 'A.m defaults') + +# Doc/lib/libinspect.tex claims there are 11 such functions +count = len(filter(lambda x:x.startswith('is'), dir(inspect))) +test(count == 11, "There are %d (not 11) is* functions", count) + +def sublistOfOne((foo)): return 1 + +args, varargs, varkw, defaults = inspect.getargspec(sublistOfOne) +test(args == [['foo']], 'sublistOfOne args') +test(varargs is None, 'sublistOfOne varargs') +test(varkw is None, 'sublistOfOne varkw') +test(defaults is None, 'sublistOfOn defaults') This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <pj...@us...> - 2008-01-18 00:23:50
|
Revision: 4053 http://jython.svn.sourceforge.net/jython/?rev=4053&view=rev Author: pjenvey Date: 2008-01-17 16:23:48 -0800 (Thu, 17 Jan 2008) Log Message: ----------- merge trunk r3200 back in, (it was lost in the 2.3 branch merge) and a little more of the same to fix inspect.getargspec and test_inspect: Don't check co_code on Jython since it doesn't exist. The means getargs doesn't correctly return anonymous tuple arguments Modified Paths: -------------- trunk/jython/Lib/inspect.py trunk/jython/Lib/test/regrtest.py trunk/jython/Lib/test/test_inspect.py Modified: trunk/jython/Lib/inspect.py =================================================================== --- trunk/jython/Lib/inspect.py 2008-01-18 00:03:05 UTC (rev 4052) +++ trunk/jython/Lib/inspect.py 2008-01-18 00:23:48 UTC (rev 4053) @@ -28,7 +28,7 @@ __author__ = 'Ka-Ping Yee <pi...@lf...>' __date__ = '1 Jan 2001' -import sys, os, types, string, re, dis, imp, tokenize, linecache +import sys, os, types, string, re, imp, tokenize, linecache # ----------------------------------------------------------- type-checking def ismodule(object): @@ -599,7 +599,11 @@ if not iscode(co): raise TypeError('arg is not a code object') - code = co.co_code + if not sys.platform.startswith('java'): + # Jython doesn't have co_code + code = co.co_code + import dis + nargs = co.co_argcount names = co.co_varnames args = list(names[:nargs]) Modified: trunk/jython/Lib/test/regrtest.py =================================================================== --- trunk/jython/Lib/test/regrtest.py 2008-01-18 00:03:05 UTC (rev 4052) +++ trunk/jython/Lib/test/regrtest.py 2008-01-18 00:23:48 UTC (rev 4053) @@ -1033,7 +1033,6 @@ test_eof test_frozen test_hexoct - test_inspect test_marshal test_new test_pep263 Modified: trunk/jython/Lib/test/test_inspect.py =================================================================== --- trunk/jython/Lib/test/test_inspect.py 2008-01-18 00:03:05 UTC (rev 4052) +++ trunk/jython/Lib/test/test_inspect.py 
2008-01-18 00:23:48 UTC (rev 4053) @@ -63,7 +63,7 @@ # getargvalues, formatargspec, formatargvalues, currentframe, stack, trace # isdatadescriptor -from test.test_support import TestFailed, TESTFN +from test.test_support import TestFailed, TESTFN, is_jython import sys, imp, os, string def test(assertion, message, *args): @@ -95,7 +95,7 @@ except: tb = sys.exc_traceback -istest(inspect.isbuiltin, 'sys.exit') +istest(inspect.isbuiltin, 'ord') istest(inspect.isbuiltin, '[].append') istest(inspect.isclass, 'mod.StupidGit') istest(inspect.iscode, 'mod.spam.func_code') @@ -188,13 +188,15 @@ test(inspect.formatargvalues(args, varargs, varkw, locals) == '(x=11, y=14)', 'mod.fr formatted argvalues') -args, varargs, varkw, locals = inspect.getargvalues(mod.fr.f_back) -test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.fr.f_back args') -test(varargs == 'g', 'mod.fr.f_back varargs') -test(varkw == 'h', 'mod.fr.f_back varkw') -test(inspect.formatargvalues(args, varargs, varkw, locals) == - '(a=7, b=8, c=9, d=3, (e=4, (f=5,)), *g=(), **h={})', - 'mod.fr.f_back formatted argvalues') +if not is_jython: + # Jython can't handle this without co_code + args, varargs, varkw, locals = inspect.getargvalues(mod.fr.f_back) + test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.fr.f_back args') + test(varargs == 'g', 'mod.fr.f_back varargs') + test(varkw == 'h', 'mod.fr.f_back varkw') + test(inspect.formatargvalues(args, varargs, varkw, locals) == + '(a=7, b=8, c=9, d=3, (e=4, (f=5,)), *g=(), **h={})', + 'mod.fr.f_back formatted argvalues') for fname in files_to_clean_up: try: @@ -357,14 +359,16 @@ test(defaults == None, 'mod.eggs defaults') test(inspect.formatargspec(args, varargs, varkw, defaults) == '(x, y)', 'mod.eggs formatted argspec') -args, varargs, varkw, defaults = inspect.getargspec(mod.spam) -test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.spam args') -test(varargs == 'g', 'mod.spam varargs') -test(varkw == 'h', 'mod.spam varkw') -test(defaults == (3, (4, (5,))), 
'mod.spam defaults') -test(inspect.formatargspec(args, varargs, varkw, defaults) == - '(a, b, c, d=3, (e, (f,))=(4, (5,)), *g, **h)', - 'mod.spam formatted argspec') +if not is_jython: + # Jython can't handle this without co_code + args, varargs, varkw, defaults = inspect.getargspec(mod.spam) + test(args == ['a', 'b', 'c', 'd', ['e', ['f']]], 'mod.spam args') + test(varargs == 'g', 'mod.spam varargs') + test(varkw == 'h', 'mod.spam varkw') + test(defaults == (3, (4, (5,))), 'mod.spam defaults') + test(inspect.formatargspec(args, varargs, varkw, defaults) == + '(a, b, c, d=3, (e, (f,))=(4, (5,)), *g, **h)', + 'mod.spam formatted argspec') args, varargs, varkw, defaults = inspect.getargspec(A.m) test(args == ['self'], 'A.m args') test(varargs is None, 'A.m varargs') @@ -378,7 +382,9 @@ def sublistOfOne((foo)): return 1 args, varargs, varkw, defaults = inspect.getargspec(sublistOfOne) -test(args == [['foo']], 'sublistOfOne args') +if not is_jython: + # Jython can't handle this without co_code + test(args == [['foo']], 'sublistOfOne args') test(varargs is None, 'sublistOfOne varargs') test(varkw is None, 'sublistOfOne varkw') test(defaults is None, 'sublistOfOn defaults') This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <pj...@us...> - 2008-02-08 20:39:48
|
Revision: 4113 http://jython.svn.sourceforge.net/jython/?rev=4113&view=rev Author: pjenvey Date: 2008-02-08 12:39:38 -0800 (Fri, 08 Feb 2008) Log Message: ----------- for changes to the compiler package, from http://svn.python.org/projects/python/branches/release23-maint/Lib@60660 Added Paths: ----------- trunk/jython/Lib/compiler/ trunk/jython/Lib/compiler/pycodegen.py trunk/jython/Lib/compiler/transformer.py Added: trunk/jython/Lib/compiler/pycodegen.py =================================================================== --- trunk/jython/Lib/compiler/pycodegen.py (rev 0) +++ trunk/jython/Lib/compiler/pycodegen.py 2008-02-08 20:39:38 UTC (rev 4113) @@ -0,0 +1,1384 @@ +import imp +import os +import marshal +import struct +import sys +import types +from cStringIO import StringIO + +from compiler import ast, parse, walk, syntax +from compiler import pyassem, misc, future, symbols +from compiler.consts import SC_LOCAL, SC_GLOBAL, SC_FREE, SC_CELL +from compiler.consts import CO_VARARGS, CO_VARKEYWORDS, CO_NEWLOCALS,\ + CO_NESTED, CO_GENERATOR, CO_GENERATOR_ALLOWED, CO_FUTURE_DIVISION +from compiler.pyassem import TupleArg + +# XXX The version-specific code can go, since this code only works with 2.x. +# Do we have Python 1.x or Python 2.x? 
+try: + VERSION = sys.version_info[0] +except AttributeError: + VERSION = 1 + +callfunc_opcode_info = { + # (Have *args, Have **args) : opcode + (0,0) : "CALL_FUNCTION", + (1,0) : "CALL_FUNCTION_VAR", + (0,1) : "CALL_FUNCTION_KW", + (1,1) : "CALL_FUNCTION_VAR_KW", +} + +LOOP = 1 +EXCEPT = 2 +TRY_FINALLY = 3 +END_FINALLY = 4 + +def compileFile(filename, display=0): + f = open(filename, 'U') + buf = f.read() + f.close() + mod = Module(buf, filename) + try: + mod.compile(display) + except SyntaxError: + raise + else: + f = open(filename + "c", "wb") + mod.dump(f) + f.close() + +def compile(source, filename, mode, flags=None, dont_inherit=None): + """Replacement for builtin compile() function""" + if flags is not None or dont_inherit is not None: + raise RuntimeError, "not implemented yet" + + if mode == "single": + gen = Interactive(source, filename) + elif mode == "exec": + gen = Module(source, filename) + elif mode == "eval": + gen = Expression(source, filename) + else: + raise ValueError("compile() 3rd arg must be 'exec' or " + "'eval' or 'single'") + gen.compile() + return gen.code + +class AbstractCompileMode: + + mode = None # defined by subclass + + def __init__(self, source, filename): + self.source = source + self.filename = filename + self.code = None + + def _get_tree(self): + tree = parse(self.source, self.mode) + misc.set_filename(self.filename, tree) + syntax.check(tree) + return tree + + def compile(self): + pass # implemented by subclass + + def getCode(self): + return self.code + +class Expression(AbstractCompileMode): + + mode = "eval" + + def compile(self): + tree = self._get_tree() + gen = ExpressionCodeGenerator(tree) + self.code = gen.getCode() + +class Interactive(AbstractCompileMode): + + mode = "single" + + def compile(self): + tree = self._get_tree() + gen = InteractiveCodeGenerator(tree) + self.code = gen.getCode() + +class Module(AbstractCompileMode): + + mode = "exec" + + def compile(self, display=0): + tree = self._get_tree() + gen = 
ModuleCodeGenerator(tree) + if display: + import pprint + print pprint.pprint(tree) + self.code = gen.getCode() + + def dump(self, f): + f.write(self.getPycHeader()) + marshal.dump(self.code, f) + + MAGIC = imp.get_magic() + + def getPycHeader(self): + # compile.c uses marshal to write a long directly, with + # calling the interface that would also generate a 1-byte code + # to indicate the type of the value. simplest way to get the + # same effect is to call marshal and then skip the code. + mtime = os.path.getmtime(self.filename) + mtime = struct.pack('<i', mtime) + return self.MAGIC + mtime + +class LocalNameFinder: + """Find local names in scope""" + def __init__(self, names=()): + self.names = misc.Set() + self.globals = misc.Set() + for name in names: + self.names.add(name) + + # XXX list comprehensions and for loops + + def getLocals(self): + for elt in self.globals.elements(): + if self.names.has_elt(elt): + self.names.remove(elt) + return self.names + + def visitDict(self, node): + pass + + def visitGlobal(self, node): + for name in node.names: + self.globals.add(name) + + def visitFunction(self, node): + self.names.add(node.name) + + def visitLambda(self, node): + pass + + def visitImport(self, node): + for name, alias in node.names: + self.names.add(alias or name) + + def visitFrom(self, node): + for name, alias in node.names: + self.names.add(alias or name) + + def visitClass(self, node): + self.names.add(node.name) + + def visitAssName(self, node): + self.names.add(node.name) + +def is_constant_false(node): + if isinstance(node, ast.Const): + if not node.value: + return 1 + return 0 + +class CodeGenerator: + """Defines basic code generator for Python bytecode + + This class is an abstract base class. Concrete subclasses must + define an __init__() that defines self.graph and then calls the + __init__() defined in this class. + + The concrete class must also define the class attributes + NameFinder, FunctionGen, and ClassGen. 
These attributes can be + defined in the initClass() method, which is a hook for + initializing these methods after all the classes have been + defined. + """ + + optimized = 0 # is namespace access optimized? + __initialized = None + class_name = None # provide default for instance variable + + def __init__(self): + if self.__initialized is None: + self.initClass() + self.__class__.__initialized = 1 + self.checkClass() + self.locals = misc.Stack() + self.setups = misc.Stack() + self.curStack = 0 + self.maxStack = 0 + self.last_lineno = None + self._setupGraphDelegation() + self._div_op = "BINARY_DIVIDE" + + # XXX set flags based on future features + futures = self.get_module().futures + for feature in futures: + if feature == "division": + self.graph.setFlag(CO_FUTURE_DIVISION) + self._div_op = "BINARY_TRUE_DIVIDE" + elif feature == "generators": + self.graph.setFlag(CO_GENERATOR_ALLOWED) + + def initClass(self): + """This method is called once for each class""" + + def checkClass(self): + """Verify that class is constructed correctly""" + try: + assert hasattr(self, 'graph') + assert getattr(self, 'NameFinder') + assert getattr(self, 'FunctionGen') + assert getattr(self, 'ClassGen') + except AssertionError, msg: + intro = "Bad class construction for %s" % self.__class__.__name__ + raise AssertionError, intro + + def _setupGraphDelegation(self): + self.emit = self.graph.emit + self.newBlock = self.graph.newBlock + self.startBlock = self.graph.startBlock + self.nextBlock = self.graph.nextBlock + self.setDocstring = self.graph.setDocstring + + def getCode(self): + """Return a code object""" + return self.graph.getCode() + + def mangle(self, name): + if self.class_name is not None: + return misc.mangle(name, self.class_name) + else: + return name + + def parseSymbols(self, tree): + s = symbols.SymbolVisitor() + walk(tree, s) + return s.scopes + + def get_module(self): + raise RuntimeError, "should be implemented by subclasses" + + # Next five methods handle name 
access + + def isLocalName(self, name): + return self.locals.top().has_elt(name) + + def storeName(self, name): + self._nameOp('STORE', name) + + def loadName(self, name): + self._nameOp('LOAD', name) + + def delName(self, name): + self._nameOp('DELETE', name) + + def _nameOp(self, prefix, name): + name = self.mangle(name) + scope = self.scope.check_name(name) + if scope == SC_LOCAL: + if not self.optimized: + self.emit(prefix + '_NAME', name) + else: + self.emit(prefix + '_FAST', name) + elif scope == SC_GLOBAL: + if not self.optimized: + self.emit(prefix + '_NAME', name) + else: + self.emit(prefix + '_GLOBAL', name) + elif scope == SC_FREE or scope == SC_CELL: + self.emit(prefix + '_DEREF', name) + else: + raise RuntimeError, "unsupported scope for var %s: %d" % \ + (name, scope) + + def _implicitNameOp(self, prefix, name): + """Emit name ops for names generated implicitly by for loops + + The interpreter generates names that start with a period or + dollar sign. The symbol table ignores these names because + they aren't present in the program text. + """ + if self.optimized: + self.emit(prefix + '_FAST', name) + else: + self.emit(prefix + '_NAME', name) + + # The set_lineno() function and the explicit emit() calls for + # SET_LINENO below are only used to generate the line number table. + # As of Python 2.3, the interpreter does not have a SET_LINENO + # instruction. pyassem treats SET_LINENO opcodes as a special case. + + def set_lineno(self, node, force=False): + """Emit SET_LINENO if necessary. + + The instruction is considered necessary if the node has a + lineno attribute and it is different than the last lineno + emitted. + + Returns true if SET_LINENO was emitted. + + There are no rules for when an AST node should have a lineno + attribute. The transformer and AST code need to be reviewed + and a consistent policy implemented and documented. Until + then, this method works around missing line numbers. 
+ """ + lineno = getattr(node, 'lineno', None) + if lineno is not None and (lineno != self.last_lineno + or force): + self.emit('SET_LINENO', lineno) + self.last_lineno = lineno + return True + return False + + # The first few visitor methods handle nodes that generator new + # code objects. They use class attributes to determine what + # specialized code generators to use. + + NameFinder = LocalNameFinder + FunctionGen = None + ClassGen = None + + def visitModule(self, node): + self.scopes = self.parseSymbols(node) + self.scope = self.scopes[node] + self.emit('SET_LINENO', 0) + if node.doc: + self.emit('LOAD_CONST', node.doc) + self.storeName('__doc__') + lnf = walk(node.node, self.NameFinder(), verbose=0) + self.locals.push(lnf.getLocals()) + self.visit(node.node) + self.emit('LOAD_CONST', None) + self.emit('RETURN_VALUE') + + def visitExpression(self, node): + self.set_lineno(node) + self.scopes = self.parseSymbols(node) + self.scope = self.scopes[node] + self.visit(node.node) + self.emit('RETURN_VALUE') + + def visitFunction(self, node): + self._visitFuncOrLambda(node, isLambda=0) + if node.doc: + self.setDocstring(node.doc) + self.storeName(node.name) + + def visitLambda(self, node): + self._visitFuncOrLambda(node, isLambda=1) + + def _visitFuncOrLambda(self, node, isLambda=0): + gen = self.FunctionGen(node, self.scopes, isLambda, + self.class_name, self.get_module()) + walk(node.code, gen) + gen.finish() + self.set_lineno(node) + for default in node.defaults: + self.visit(default) + frees = gen.scope.get_free_vars() + if frees: + for name in frees: + self.emit('LOAD_CLOSURE', name) + self.emit('LOAD_CONST', gen) + self.emit('MAKE_CLOSURE', len(node.defaults)) + else: + self.emit('LOAD_CONST', gen) + self.emit('MAKE_FUNCTION', len(node.defaults)) + + def visitClass(self, node): + gen = self.ClassGen(node, self.scopes, + self.get_module()) + walk(node.code, gen) + gen.finish() + self.set_lineno(node) + self.emit('LOAD_CONST', node.name) + for base in 
node.bases: + self.visit(base) + self.emit('BUILD_TUPLE', len(node.bases)) + frees = gen.scope.get_free_vars() + for name in frees: + self.emit('LOAD_CLOSURE', name) + self.emit('LOAD_CONST', gen) + if frees: + self.emit('MAKE_CLOSURE', 0) + else: + self.emit('MAKE_FUNCTION', 0) + self.emit('CALL_FUNCTION', 0) + self.emit('BUILD_CLASS') + self.storeName(node.name) + + # The rest are standard visitor methods + + # The next few implement control-flow statements + + def visitIf(self, node): + end = self.newBlock() + numtests = len(node.tests) + for i in range(numtests): + test, suite = node.tests[i] + if is_constant_false(test): + # XXX will need to check generator stuff here + continue + self.set_lineno(test) + self.visit(test) + nextTest = self.newBlock() + self.emit('JUMP_IF_FALSE', nextTest) + self.nextBlock() + self.emit('POP_TOP') + self.visit(suite) + self.emit('JUMP_FORWARD', end) + self.startBlock(nextTest) + self.emit('POP_TOP') + if node.else_: + self.visit(node.else_) + self.nextBlock(end) + + def visitWhile(self, node): + self.set_lineno(node) + + loop = self.newBlock() + else_ = self.newBlock() + + after = self.newBlock() + self.emit('SETUP_LOOP', after) + + self.nextBlock(loop) + self.setups.push((LOOP, loop)) + + self.set_lineno(node, force=True) + self.visit(node.test) + self.emit('JUMP_IF_FALSE', else_ or after) + + self.nextBlock() + self.emit('POP_TOP') + self.visit(node.body) + self.emit('JUMP_ABSOLUTE', loop) + + self.startBlock(else_) # or just the POPs if not else clause + self.emit('POP_TOP') + self.emit('POP_BLOCK') + self.setups.pop() + if node.else_: + self.visit(node.else_) + self.nextBlock(after) + + def visitFor(self, node): + start = self.newBlock() + anchor = self.newBlock() + after = self.newBlock() + self.setups.push((LOOP, start)) + + self.set_lineno(node) + self.emit('SETUP_LOOP', after) + self.visit(node.list) + self.emit('GET_ITER') + + self.nextBlock(start) + self.set_lineno(node, force=1) + self.emit('FOR_ITER', anchor) + 
self.visit(node.assign) + self.visit(node.body) + self.emit('JUMP_ABSOLUTE', start) + self.nextBlock(anchor) + self.emit('POP_BLOCK') + self.setups.pop() + if node.else_: + self.visit(node.else_) + self.nextBlock(after) + + def visitBreak(self, node): + if not self.setups: + raise SyntaxError, "'break' outside loop (%s, %d)" % \ + (node.filename, node.lineno) + self.set_lineno(node) + self.emit('BREAK_LOOP') + + def visitContinue(self, node): + if not self.setups: + raise SyntaxError, "'continue' outside loop (%s, %d)" % \ + (node.filename, node.lineno) + kind, block = self.setups.top() + if kind == LOOP: + self.set_lineno(node) + self.emit('JUMP_ABSOLUTE', block) + self.nextBlock() + elif kind == EXCEPT or kind == TRY_FINALLY: + self.set_lineno(node) + # find the block that starts the loop + top = len(self.setups) + while top > 0: + top = top - 1 + kind, loop_block = self.setups[top] + if kind == LOOP: + break + if kind != LOOP: + raise SyntaxError, "'continue' outside loop (%s, %d)" % \ + (node.filename, node.lineno) + self.emit('CONTINUE_LOOP', loop_block) + self.nextBlock() + elif kind == END_FINALLY: + msg = "'continue' not allowed inside 'finally' clause (%s, %d)" + raise SyntaxError, msg % (node.filename, node.lineno) + + def visitTest(self, node, jump): + end = self.newBlock() + for child in node.nodes[:-1]: + self.visit(child) + self.emit(jump, end) + self.nextBlock() + self.emit('POP_TOP') + self.visit(node.nodes[-1]) + self.nextBlock(end) + + def visitAnd(self, node): + self.visitTest(node, 'JUMP_IF_FALSE') + + def visitOr(self, node): + self.visitTest(node, 'JUMP_IF_TRUE') + + def visitCompare(self, node): + self.visit(node.expr) + cleanup = self.newBlock() + for op, code in node.ops[:-1]: + self.visit(code) + self.emit('DUP_TOP') + self.emit('ROT_THREE') + self.emit('COMPARE_OP', op) + self.emit('JUMP_IF_FALSE', cleanup) + self.nextBlock() + self.emit('POP_TOP') + # now do the last comparison + if node.ops: + op, code = node.ops[-1] + self.visit(code) 
+ self.emit('COMPARE_OP', op) + if len(node.ops) > 1: + end = self.newBlock() + self.emit('JUMP_FORWARD', end) + self.startBlock(cleanup) + self.emit('ROT_TWO') + self.emit('POP_TOP') + self.nextBlock(end) + + # list comprehensions + __list_count = 0 + + def visitListComp(self, node): + self.set_lineno(node) + # setup list + append = "$append%d" % self.__list_count + self.__list_count = self.__list_count + 1 + self.emit('BUILD_LIST', 0) + self.emit('DUP_TOP') + self.emit('LOAD_ATTR', 'append') + self._implicitNameOp('STORE', append) + + stack = [] + for i, for_ in zip(range(len(node.quals)), node.quals): + start, anchor = self.visit(for_) + cont = None + for if_ in for_.ifs: + if cont is None: + cont = self.newBlock() + self.visit(if_, cont) + stack.insert(0, (start, cont, anchor)) + + self._implicitNameOp('LOAD', append) + self.visit(node.expr) + self.emit('CALL_FUNCTION', 1) + self.emit('POP_TOP') + + for start, cont, anchor in stack: + if cont: + skip_one = self.newBlock() + self.emit('JUMP_FORWARD', skip_one) + self.startBlock(cont) + self.emit('POP_TOP') + self.nextBlock(skip_one) + self.emit('JUMP_ABSOLUTE', start) + self.startBlock(anchor) + self._implicitNameOp('DELETE', append) + + self.__list_count = self.__list_count - 1 + + def visitListCompFor(self, node): + start = self.newBlock() + anchor = self.newBlock() + + self.visit(node.list) + self.emit('GET_ITER') + self.nextBlock(start) + self.set_lineno(node, force=True) + self.emit('FOR_ITER', anchor) + self.nextBlock() + self.visit(node.assign) + return start, anchor + + def visitListCompIf(self, node, branch): + self.set_lineno(node, force=True) + self.visit(node.test) + self.emit('JUMP_IF_FALSE', branch) + self.newBlock() + self.emit('POP_TOP') + + # exception related + + def visitAssert(self, node): + # XXX would be interesting to implement this via a + # transformation of the AST before this stage + end = self.newBlock() + self.set_lineno(node) + # XXX __debug__ and AssertionError appear to be special 
cases + # -- they are always loaded as globals even if there are local + # names. I guess this is a sort of renaming op. + self.emit('LOAD_GLOBAL', '__debug__') + self.emit('JUMP_IF_FALSE', end) + self.nextBlock() + self.emit('POP_TOP') + self.visit(node.test) + self.emit('JUMP_IF_TRUE', end) + self.nextBlock() + self.emit('POP_TOP') + self.emit('LOAD_GLOBAL', 'AssertionError') + if node.fail: + self.visit(node.fail) + self.emit('RAISE_VARARGS', 2) + else: + self.emit('RAISE_VARARGS', 1) + self.nextBlock(end) + self.emit('POP_TOP') + + def visitRaise(self, node): + self.set_lineno(node) + n = 0 + if node.expr1: + self.visit(node.expr1) + n = n + 1 + if node.expr2: + self.visit(node.expr2) + n = n + 1 + if node.expr3: + self.visit(node.expr3) + n = n + 1 + self.emit('RAISE_VARARGS', n) + + def visitTryExcept(self, node): + body = self.newBlock() + handlers = self.newBlock() + end = self.newBlock() + if node.else_: + lElse = self.newBlock() + else: + lElse = end + self.set_lineno(node) + self.emit('SETUP_EXCEPT', handlers) + self.nextBlock(body) + self.setups.push((EXCEPT, body)) + self.visit(node.body) + self.emit('POP_BLOCK') + self.setups.pop() + self.emit('JUMP_FORWARD', lElse) + self.startBlock(handlers) + + last = len(node.handlers) - 1 + for i in range(len(node.handlers)): + expr, target, body = node.handlers[i] + self.set_lineno(expr) + if expr: + self.emit('DUP_TOP') + self.visit(expr) + self.emit('COMPARE_OP', 'exception match') + next = self.newBlock() + self.emit('JUMP_IF_FALSE', next) + self.nextBlock() + self.emit('POP_TOP') + self.emit('POP_TOP') + if target: + self.visit(target) + else: + self.emit('POP_TOP') + self.emit('POP_TOP') + self.visit(body) + self.emit('JUMP_FORWARD', end) + if expr: + self.nextBlock(next) + else: + self.nextBlock() + if expr: # XXX + self.emit('POP_TOP') + self.emit('END_FINALLY') + if node.else_: + self.nextBlock(lElse) + self.visit(node.else_) + self.nextBlock(end) + + def visitTryFinally(self, node): + body = 
self.newBlock() + final = self.newBlock() + self.set_lineno(node) + self.emit('SETUP_FINALLY', final) + self.nextBlock(body) + self.setups.push((TRY_FINALLY, body)) + self.visit(node.body) + self.emit('POP_BLOCK') + self.setups.pop() + self.emit('LOAD_CONST', None) + self.nextBlock(final) + self.setups.push((END_FINALLY, final)) + self.visit(node.final) + self.emit('END_FINALLY') + self.setups.pop() + + # misc + + def visitDiscard(self, node): + self.set_lineno(node) + self.visit(node.expr) + self.emit('POP_TOP') + + def visitConst(self, node): + self.emit('LOAD_CONST', node.value) + + def visitKeyword(self, node): + self.emit('LOAD_CONST', node.name) + self.visit(node.expr) + + def visitGlobal(self, node): + # no code to generate + pass + + def visitName(self, node): + self.set_lineno(node) + self.loadName(node.name) + + def visitPass(self, node): + self.set_lineno(node) + + def visitImport(self, node): + self.set_lineno(node) + for name, alias in node.names: + if VERSION > 1: + self.emit('LOAD_CONST', None) + self.emit('IMPORT_NAME', name) + mod = name.split(".")[0] + if alias: + self._resolveDots(name) + self.storeName(alias) + else: + self.storeName(mod) + + def visitFrom(self, node): + self.set_lineno(node) + fromlist = map(lambda (name, alias): name, node.names) + if VERSION > 1: + self.emit('LOAD_CONST', tuple(fromlist)) + self.emit('IMPORT_NAME', node.modname) + for name, alias in node.names: + if VERSION > 1: + if name == '*': + self.namespace = 0 + self.emit('IMPORT_STAR') + # There can only be one name w/ from ... 
import * + assert len(node.names) == 1 + return + else: + self.emit('IMPORT_FROM', name) + self._resolveDots(name) + self.storeName(alias or name) + else: + self.emit('IMPORT_FROM', name) + self.emit('POP_TOP') + + def _resolveDots(self, name): + elts = name.split(".") + if len(elts) == 1: + return + for elt in elts[1:]: + self.emit('LOAD_ATTR', elt) + + def visitGetattr(self, node): + self.visit(node.expr) + self.emit('LOAD_ATTR', self.mangle(node.attrname)) + + # next five implement assignments + + def visitAssign(self, node): + self.set_lineno(node) + self.visit(node.expr) + dups = len(node.nodes) - 1 + for i in range(len(node.nodes)): + elt = node.nodes[i] + if i < dups: + self.emit('DUP_TOP') + if isinstance(elt, ast.Node): + self.visit(elt) + + def visitAssName(self, node): + if node.flags == 'OP_ASSIGN': + self.storeName(node.name) + elif node.flags == 'OP_DELETE': + self.set_lineno(node) + self.delName(node.name) + else: + print "oops", node.flags + + def visitAssAttr(self, node): + self.visit(node.expr) + if node.flags == 'OP_ASSIGN': + self.emit('STORE_ATTR', self.mangle(node.attrname)) + elif node.flags == 'OP_DELETE': + self.emit('DELETE_ATTR', self.mangle(node.attrname)) + else: + print "warning: unexpected flags:", node.flags + print node + + def _visitAssSequence(self, node, op='UNPACK_SEQUENCE'): + if findOp(node) != 'OP_DELETE': + self.emit(op, len(node.nodes)) + for child in node.nodes: + self.visit(child) + + if VERSION > 1: + visitAssTuple = _visitAssSequence + visitAssList = _visitAssSequence + else: + def visitAssTuple(self, node): + self._visitAssSequence(node, 'UNPACK_TUPLE') + + def visitAssList(self, node): + self._visitAssSequence(node, 'UNPACK_LIST') + + # augmented assignment + + def visitAugAssign(self, node): + self.set_lineno(node) + aug_node = wrap_aug(node.node) + self.visit(aug_node, "load") + self.visit(node.expr) + self.emit(self._augmented_opcode[node.op]) + self.visit(aug_node, "store") + + _augmented_opcode = { + '+=' : 
'INPLACE_ADD', + '-=' : 'INPLACE_SUBTRACT', + '*=' : 'INPLACE_MULTIPLY', + '/=' : 'INPLACE_DIVIDE', + '//=': 'INPLACE_FLOOR_DIVIDE', + '%=' : 'INPLACE_MODULO', + '**=': 'INPLACE_POWER', + '>>=': 'INPLACE_RSHIFT', + '<<=': 'INPLACE_LSHIFT', + '&=' : 'INPLACE_AND', + '^=' : 'INPLACE_XOR', + '|=' : 'INPLACE_OR', + } + + def visitAugName(self, node, mode): + if mode == "load": + self.loadName(node.name) + elif mode == "store": + self.storeName(node.name) + + def visitAugGetattr(self, node, mode): + if mode == "load": + self.visit(node.expr) + self.emit('DUP_TOP') + self.emit('LOAD_ATTR', self.mangle(node.attrname)) + elif mode == "store": + self.emit('ROT_TWO') + self.emit('STORE_ATTR', self.mangle(node.attrname)) + + def visitAugSlice(self, node, mode): + if mode == "load": + self.visitSlice(node, 1) + elif mode == "store": + slice = 0 + if node.lower: + slice = slice | 1 + if node.upper: + slice = slice | 2 + if slice == 0: + self.emit('ROT_TWO') + elif slice == 3: + self.emit('ROT_FOUR') + else: + self.emit('ROT_THREE') + self.emit('STORE_SLICE+%d' % slice) + + def visitAugSubscript(self, node, mode): + if len(node.subs) > 1: + raise SyntaxError, "augmented assignment to tuple is not possible" + if mode == "load": + self.visitSubscript(node, 1) + elif mode == "store": + self.emit('ROT_THREE') + self.emit('STORE_SUBSCR') + + def visitExec(self, node): + self.visit(node.expr) + if node.locals is None: + self.emit('LOAD_CONST', None) + else: + self.visit(node.locals) + if node.globals is None: + self.emit('DUP_TOP') + else: + self.visit(node.globals) + self.emit('EXEC_STMT') + + def visitCallFunc(self, node): + pos = 0 + kw = 0 + self.set_lineno(node) + self.visit(node.node) + for arg in node.args: + self.visit(arg) + if isinstance(arg, ast.Keyword): + kw = kw + 1 + else: + pos = pos + 1 + if node.star_args is not None: + self.visit(node.star_args) + if node.dstar_args is not None: + self.visit(node.dstar_args) + have_star = node.star_args is not None + have_dstar = 
node.dstar_args is not None + opcode = callfunc_opcode_info[have_star, have_dstar] + self.emit(opcode, kw << 8 | pos) + + def visitPrint(self, node, newline=0): + self.set_lineno(node) + if node.dest: + self.visit(node.dest) + for child in node.nodes: + if node.dest: + self.emit('DUP_TOP') + self.visit(child) + if node.dest: + self.emit('ROT_TWO') + self.emit('PRINT_ITEM_TO') + else: + self.emit('PRINT_ITEM') + if node.dest and not newline: + self.emit('POP_TOP') + + def visitPrintnl(self, node): + self.visitPrint(node, newline=1) + if node.dest: + self.emit('PRINT_NEWLINE_TO') + else: + self.emit('PRINT_NEWLINE') + + def visitReturn(self, node): + self.set_lineno(node) + self.visit(node.value) + self.emit('RETURN_VALUE') + + def visitYield(self, node): + self.set_lineno(node) + self.visit(node.value) + self.emit('YIELD_VALUE') + + # slice and subscript stuff + + def visitSlice(self, node, aug_flag=None): + # aug_flag is used by visitAugSlice + self.visit(node.expr) + slice = 0 + if node.lower: + self.visit(node.lower) + slice = slice | 1 + if node.upper: + self.visit(node.upper) + slice = slice | 2 + if aug_flag: + if slice == 0: + self.emit('DUP_TOP') + elif slice == 3: + self.emit('DUP_TOPX', 3) + else: + self.emit('DUP_TOPX', 2) + if node.flags == 'OP_APPLY': + self.emit('SLICE+%d' % slice) + elif node.flags == 'OP_ASSIGN': + self.emit('STORE_SLICE+%d' % slice) + elif node.flags == 'OP_DELETE': + self.emit('DELETE_SLICE+%d' % slice) + else: + print "weird slice", node.flags + raise + + def visitSubscript(self, node, aug_flag=None): + self.visit(node.expr) + for sub in node.subs: + self.visit(sub) + if aug_flag: + self.emit('DUP_TOPX', 2) + if len(node.subs) > 1: + self.emit('BUILD_TUPLE', len(node.subs)) + if node.flags == 'OP_APPLY': + self.emit('BINARY_SUBSCR') + elif node.flags == 'OP_ASSIGN': + self.emit('STORE_SUBSCR') + elif node.flags == 'OP_DELETE': + self.emit('DELETE_SUBSCR') + + # binary ops + + def binaryOp(self, node, op): + self.visit(node.left) + 
self.visit(node.right) + self.emit(op) + + def visitAdd(self, node): + return self.binaryOp(node, 'BINARY_ADD') + + def visitSub(self, node): + return self.binaryOp(node, 'BINARY_SUBTRACT') + + def visitMul(self, node): + return self.binaryOp(node, 'BINARY_MULTIPLY') + + def visitDiv(self, node): + return self.binaryOp(node, self._div_op) + + def visitFloorDiv(self, node): + return self.binaryOp(node, 'BINARY_FLOOR_DIVIDE') + + def visitMod(self, node): + return self.binaryOp(node, 'BINARY_MODULO') + + def visitPower(self, node): + return self.binaryOp(node, 'BINARY_POWER') + + def visitLeftShift(self, node): + return self.binaryOp(node, 'BINARY_LSHIFT') + + def visitRightShift(self, node): + return self.binaryOp(node, 'BINARY_RSHIFT') + + # unary ops + + def unaryOp(self, node, op): + self.visit(node.expr) + self.emit(op) + + def visitInvert(self, node): + return self.unaryOp(node, 'UNARY_INVERT') + + def visitUnarySub(self, node): + return self.unaryOp(node, 'UNARY_NEGATIVE') + + def visitUnaryAdd(self, node): + return self.unaryOp(node, 'UNARY_POSITIVE') + + def visitUnaryInvert(self, node): + return self.unaryOp(node, 'UNARY_INVERT') + + def visitNot(self, node): + return self.unaryOp(node, 'UNARY_NOT') + + def visitBackquote(self, node): + return self.unaryOp(node, 'UNARY_CONVERT') + + # bit ops + + def bitOp(self, nodes, op): + self.visit(nodes[0]) + for node in nodes[1:]: + self.visit(node) + self.emit(op) + + def visitBitand(self, node): + return self.bitOp(node.nodes, 'BINARY_AND') + + def visitBitor(self, node): + return self.bitOp(node.nodes, 'BINARY_OR') + + def visitBitxor(self, node): + return self.bitOp(node.nodes, 'BINARY_XOR') + + # object constructors + + def visitEllipsis(self, node): + self.emit('LOAD_CONST', Ellipsis) + + def visitTuple(self, node): + self.set_lineno(node) + for elt in node.nodes: + self.visit(elt) + self.emit('BUILD_TUPLE', len(node.nodes)) + + def visitList(self, node): + self.set_lineno(node) + for elt in node.nodes: + 
self.visit(elt) + self.emit('BUILD_LIST', len(node.nodes)) + + def visitSliceobj(self, node): + for child in node.nodes: + self.visit(child) + self.emit('BUILD_SLICE', len(node.nodes)) + + def visitDict(self, node): + self.set_lineno(node) + self.emit('BUILD_MAP', 0) + for k, v in node.items: + self.emit('DUP_TOP') + self.visit(k) + self.visit(v) + self.emit('ROT_THREE') + self.emit('STORE_SUBSCR') + +class NestedScopeMixin: + """Defines initClass() for nested scoping (Python 2.2-compatible)""" + def initClass(self): + self.__class__.NameFinder = LocalNameFinder + self.__class__.FunctionGen = FunctionCodeGenerator + self.__class__.ClassGen = ClassCodeGenerator + +class ModuleCodeGenerator(NestedScopeMixin, CodeGenerator): + __super_init = CodeGenerator.__init__ + + scopes = None + + def __init__(self, tree): + self.graph = pyassem.PyFlowGraph("<module>", tree.filename) + self.futures = future.find_futures(tree) + self.__super_init() + walk(tree, self) + + def get_module(self): + return self + +class ExpressionCodeGenerator(NestedScopeMixin, CodeGenerator): + __super_init = CodeGenerator.__init__ + + scopes = None + futures = () + + def __init__(self, tree): + self.graph = pyassem.PyFlowGraph("<expression>", tree.filename) + self.__super_init() + walk(tree, self) + + def get_module(self): + return self + +class InteractiveCodeGenerator(NestedScopeMixin, CodeGenerator): + + __super_init = CodeGenerator.__init__ + + scopes = None + futures = () + + def __init__(self, tree): + self.graph = pyassem.PyFlowGraph("<interactive>", tree.filename) + self.__super_init() + self.set_lineno(tree) + walk(tree, self) + self.emit('RETURN_VALUE') + + def get_module(self): + return self + + def visitDiscard(self, node): + # XXX Discard means it's an expression. Perhaps this is a bad + # name. 
+ self.visit(node.expr) + self.emit('PRINT_EXPR') + +class AbstractFunctionCode: + optimized = 1 + lambdaCount = 0 + + def __init__(self, func, scopes, isLambda, class_name, mod): + self.class_name = class_name + self.module = mod + if isLambda: + klass = FunctionCodeGenerator + name = "<lambda.%d>" % klass.lambdaCount + klass.lambdaCount = klass.lambdaCount + 1 + else: + name = func.name + args, hasTupleArg = generateArgList(func.argnames) + self.graph = pyassem.PyFlowGraph(name, func.filename, args, + optimized=1) + self.isLambda = isLambda + self.super_init() + + if not isLambda and func.doc: + self.setDocstring(func.doc) + + lnf = walk(func.code, self.NameFinder(args), verbose=0) + self.locals.push(lnf.getLocals()) + if func.varargs: + self.graph.setFlag(CO_VARARGS) + if func.kwargs: + self.graph.setFlag(CO_VARKEYWORDS) + self.set_lineno(func) + if hasTupleArg: + self.generateArgUnpack(func.argnames) + + def get_module(self): + return self.module + + def finish(self): + self.graph.startExitBlock() + if not self.isLambda: + self.emit('LOAD_CONST', None) + self.emit('RETURN_VALUE') + + def generateArgUnpack(self, args): + for i in range(len(args)): + arg = args[i] + if type(arg) == types.TupleType: + self.emit('LOAD_FAST', '.%d' % (i * 2)) + self.unpackSequence(arg) + + def unpackSequence(self, tup): + if VERSION > 1: + self.emit('UNPACK_SEQUENCE', len(tup)) + else: + self.emit('UNPACK_TUPLE', len(tup)) + for elt in tup: + if type(elt) == types.TupleType: + self.unpackSequence(elt) + else: + self._nameOp('STORE', elt) + + unpackTuple = unpackSequence + +class FunctionCodeGenerator(NestedScopeMixin, AbstractFunctionCode, + CodeGenerator): + super_init = CodeGenerator.__init__ # call be other init + scopes = None + + __super_init = AbstractFunctionCode.__init__ + + def __init__(self, func, scopes, isLambda, class_name, mod): + self.scopes = scopes + self.scope = scopes[func] + self.__super_init(func, scopes, isLambda, class_name, mod) + 
self.graph.setFreeVars(self.scope.get_free_vars()) + self.graph.setCellVars(self.scope.get_cell_vars()) + if self.scope.generator is not None: + self.graph.setFlag(CO_GENERATOR) + +class AbstractClassCode: + + def __init__(self, klass, scopes, module): + self.class_name = klass.name + self.module = module + self.graph = pyassem.PyFlowGraph(klass.name, klass.filename, + optimized=0, klass=1) + self.super_init() + lnf = walk(klass.code, self.NameFinder(), verbose=0) + self.locals.push(lnf.getLocals()) + self.graph.setFlag(CO_NEWLOCALS) + if klass.doc: + self.setDocstring(klass.doc) + + def get_module(self): + return self.module + + def finish(self): + self.graph.startExitBlock() + self.emit('LOAD_LOCALS') + self.emit('RETURN_VALUE') + +class ClassCodeGenerator(NestedScopeMixin, AbstractClassCode, CodeGenerator): + super_init = CodeGenerator.__init__ + scopes = None + + __super_init = AbstractClassCode.__init__ + + def __init__(self, klass, scopes, module): + self.scopes = scopes + self.scope = scopes[klass] + self.__super_init(klass, scopes, module) + self.graph.setFreeVars(self.scope.get_free_vars()) + self.graph.setCellVars(self.scope.get_cell_vars()) + self.set_lineno(klass) + self.emit("LOAD_GLOBAL", "__name__") + self.storeName("__module__") + if klass.doc: + self.emit("LOAD_CONST", klass.doc) + self.storeName('__doc__') + +def generateArgList(arglist): + """Generate an arg list marking TupleArgs""" + args = [] + extra = [] + count = 0 + for i in range(len(arglist)): + elt = arglist[i] + if type(elt) == types.StringType: + args.append(elt) + elif type(elt) == types.TupleType: + args.append(TupleArg(i * 2, elt)) + extra.extend(misc.flatten(elt)) + count = count + 1 + else: + raise ValueError, "unexpect argument type:", elt + return args + extra, count + +def findOp(node): + """Find the op (DELETE, LOAD, STORE) in an AssTuple tree""" + v = OpFinder() + walk(node, v, verbose=0) + return v.op + +class OpFinder: + def __init__(self): + self.op = None + def 
visitAssName(self, node): + if self.op is None: + self.op = node.flags + elif self.op != node.flags: + raise ValueError, "mixed ops in stmt" + visitAssAttr = visitAssName + visitSubscript = visitAssName + +class Delegator: + """Base class to support delegation for augmented assignment nodes + + To generator code for augmented assignments, we use the following + wrapper classes. In visitAugAssign, the left-hand expression node + is visited twice. The first time the visit uses the normal method + for that node . The second time the visit uses a different method + that generates the appropriate code to perform the assignment. + These delegator classes wrap the original AST nodes in order to + support the variant visit methods. + """ + def __init__(self, obj): + self.obj = obj + + def __getattr__(self, attr): + return getattr(self.obj, attr) + +class AugGetattr(Delegator): + pass + +class AugName(Delegator): + pass + +class AugSlice(Delegator): + pass + +class AugSubscript(Delegator): + pass + +wrapper = { + ast.Getattr: AugGetattr, + ast.Name: AugName, + ast.Slice: AugSlice, + ast.Subscript: AugSubscript, + } + +def wrap_aug(node): + return wrapper[node.__class__](node) + +if __name__ == "__main__": + for file in sys.argv[1:]: + compileFile(file) Added: trunk/jython/Lib/compiler/transformer.py =================================================================== --- trunk/jython/Lib/compiler/transformer.py (rev 0) +++ trunk/jython/Lib/compiler/transformer.py 2008-02-08 20:39:38 UTC (rev 4113) @@ -0,0 +1,1382 @@ +"""Parse tree transformation module. + +Transforms Python source code into an abstract syntax tree (AST) +defined in the ast module. + +The simplest ways to invoke this module are via parse and parseFile. +parse(buf) -> AST +parseFile(path) -> AST +""" + +# Original version written by Greg Stein (gs...@ly...) +# and Bill Tutt (ras...@li...) +# February 1997. 
+# +# Modifications and improvements for Python 2.0 by Jeremy Hylton and +# Mark Hammond + +# Portions of this file are: +# Copyright (C) 1997-1998 Greg Stein. All Rights Reserved. +# +# This module is provided under a BSD-ish license. See +# http://www.opensource.org/licenses/bsd-license.html +# and replace OWNER, ORGANIZATION, and YEAR as appropriate. + +from ast import * +import parser +# Care must be taken to use only symbols and tokens defined in Python +# 1.5.2 for code branches executed in 1.5.2 +import symbol +import token +import sys + +error = 'walker.error' + +from consts import CO_VARARGS, CO_VARKEYWORDS +from consts import OP_ASSIGN, OP_DELETE, OP_APPLY + +def parseFile(path): + f = open(path) + # XXX The parser API tolerates files without a trailing newline, + # but not strings without a trailing newline. Always add an extra + # newline to the file contents, since we're going through the string + # version of the API. + src = f.read() + "\n" + f.close() + return parse(src) + +def parse(buf, mode="exec"): + if mode == "exec" or mode == "single": + return Transformer().parsesuite(buf) + elif mode == "eval": + return Transformer().parseexpr(buf) + else: + raise ValueError("compile() arg 3 must be" + " 'exec' or 'eval' or 'single'") + +def asList(nodes): + l = [] + for item in nodes: + if hasattr(item, "asList"): + l.append(item.asList()) + else: + if type(item) is type( (None, None) ): + l.append(tuple(asList(item))) + elif type(item) is type( [] ): + l.append(asList(item)) + else: + l.append(item) + return l + +def Node(*args): + kind = args[0] + if nodes.has_key(kind): + try: + return nodes[kind](*args[1:]) + except TypeError: + print nodes[kind], len(args), args + raise + else: + raise error, "Can't find appropriate Node type: %s" % str(args) + #return apply(ast.Node, args) + +class Transformer: + """Utility object for transforming Python parse trees. 
+ + Exposes the following methods: + tree = transform(ast_tree) + tree = parsesuite(text) + tree = parseexpr(text) + tree = parsefile(fileob | filename) + """ + + def __init__(self): + self._dispatch = {} + for value, name in symbol.sym_name.items(): + if hasattr(self, name): + self._dispatch[value] = getattr(self, name) + self._dispatch[token.NEWLINE] = self.com_NEWLINE + self._atom_dispatch = {token.LPAR: self.atom_lpar, + token.LSQB: self.atom_lsqb, + token.LBRACE: self.atom_lbrace, + token.BACKQUOTE: self.atom_backquote, + token.NUMBER: self.atom_number, + token.STRING: self.atom_string, + token.NAME: self.atom_name, + } + self.encoding = None + + def transform(self, tree): + """Transform an AST into a modified parse tree.""" + if type(tree) != type(()) and type(tree) != type([]): + tree = parser.ast2tuple(tree, line_info=1) + return self.compile_node(tree) + + def parsesuite(self, text): + """Return a modified parse tree for the given suite text.""" + # Hack for handling non-native line endings on non-DOS like OSs. + # this can go now we have universal newlines? + text = text.replace('\x0d', '') + return self.transform(parser.suite(text)) + + def parseexpr(self, text): + """Return a modified parse tree for the given expression text.""" + return self.transform(parser.expr(text)) + + def parsefile(self, file): + """Return a modified parse tree for the contents of the given file.""" + if type(file) == type(''): + file = open(file) + return self.parsesuite(file.read()) + + # -------------------------------------------------------------- + # + # PRIVATE METHODS + # + + def compile_node(self, node): + ### emit a line-number node? 
+ n = node[0] + + if n == symbol.encoding_decl: + self.encoding = node[2] + node = node[1] + n = node[0] + + if n == symbol.single_input: + return self.single_input(node[1:]) + if n == symbol.file_input: + return self.file_input(node[1:]) + if n == symbol.eval_input: + return self.eval_input(node[1:]) + if n == symbol.lambdef: + return self.lambdef(node[1:]) + if n == symbol.funcdef: + return self.funcdef(node[1:]) + if n == symbol.classdef: + return self.classdef(node[1:]) + + raise error, ('unexpected node type', n) + + def single_input(self, node): + ### do we want to do anything about being "interactive" ? + + # NEWLINE | simple_stmt | compound_stmt NEWLINE + n = node[0][0] + if n != token.NEWLINE: + return self.com_stmt(node[0]) + + return Pass() + + def file_input(self, nodelist): + doc = self.get_docstring(nodelist, symbol.file_input) + if doc is not None: + i = 1 + else: + i = 0 + stmts = [] + for node in nodelist[i:]: + if node[0] != token.ENDMARKER and node[0] != token.NEWLINE: + self.com_append_stmt(stmts, node) + return Module(doc, Stmt(stmts)) + + def eval_input(self, nodelist): + # from the built-in function input() + ### is this sufficient? 
+ return Expression(self.com_node(nodelist[0])) + + def funcdef(self, nodelist): + # funcdef: 'def' NAME parameters ':' suite + # parameters: '(' [varargslist] ')' + + lineno = nodelist[1][2] + name = nodelist[1][1] + args = nodelist[2][2] + + if args[0] == symbol.varargslist: + names, defaults, flags = self.com_arglist(args[1:]) + else: + names = defaults = () + flags = 0 + doc = self.get_docstring(nodelist[4]) + + # code for function + code = self.com_node(nodelist[4]) + + if doc is not None: + assert isinstance(code, Stmt) + assert isinstance(code.nodes[0], Discard) + del code.nodes[0] + n = Function(name, names, defaults, flags, doc, code) + n.lineno = lineno + return n + + def lambdef(self, nodelist): + # lambdef: 'lambda' [varargslist] ':' test + if nodelist[2][0] == symbol.varargslist: + names, defaults, flags = self.com_arglist(nodelist[2][1:]) + else: + names = defaults = () + flags = 0 + + # code for lambda + code = self.com_node(nodelist[-1]) + + n = Lambda(names, defaults, flags, code) + n.lineno = nodelist[1][2] + return n + + def classdef(self, nodelist): + # classdef: 'class' NAME ['(' testlist ')'] ':' suite + + name = nodelist[1][1] + doc = self.get_docstring(nodelist[-1]) + if nodelist[2][0] == token.COLON: + bases = [] + else: + bases = self.com_bases(nodelist[3]) + + # code for class + code = self.com_node(nodelist[-1]) + + if doc is not None: + assert isinstance(code, Stmt) + assert isinstance(code.nodes[0], Discard) + del code.nodes[0] + + n = Class(name, bases, doc, code) + n.lineno = nodelist[1][2] + return n + + def stmt(self, nodelist): + return self.com_stmt(nodelist[0]) + + small_stmt = stmt + flow_stmt = stmt + compound_stmt = stmt + + def simple_stmt(self, nodelist): + # small_stmt (';' small_stmt)* [';'] NEWLINE + stmts = [] + for i in range(0, len(nodelist), 2): + self.com_append_stmt(stmts, nodelist[i]) + return Stmt(stmts) + + def parameters(self, nodelist): + raise error + + def varargslist(self, nodelist): + raise error + + def 
fpdef(self, nodelist): + raise error + + def fplist(self, nodelist): + raise error + + def dotted_name(self, nodelist): + raise error + + def comp_op(self, nodelist): + raise error + + def trailer(self, nodelist): + raise error + + def sliceop(self, nodelist): + raise error + + def argument(self, nodelist): + raise error + + # -------------------------------------------------------------- + # + # STATEMENT NODES (invoked by com_node()) + # + + def expr_stmt(self, nodelist): + # augassign testlist | testlist ('=' testlist)* + en = nodelist[-1] + exprNode = self.lookup_node(en)(en[1:]) + if len(nodelist) == 1: + n = Discard(exprNode) + n.lineno = exprNode.lineno + return n + if nodelist[1][0] == token.EQUAL: + nodesl = [] + for i in range(0, len(nodelist) - 2, 2): + nodesl.append(self.com_assign(nodelist[i], OP_ASSIGN)) + n = Assign(nodesl, exprNode) + n.lineno = nodelist[1][2] + else: + lval = self.com_augassign(nodelist[0]) + op = self.com_augassign_op(nodelist[1]) + n = AugAssign(lval, op[1], exprNode) + n.lineno = op[2] + return n + + def print_stmt(self, nodelist): + # print ([ test (',' test)* [','] ] | '>>' test [ (',' test)+ [','] ]) + items = [] + if len(nodelist) == 1: + start = 1 + dest = None + elif nodelist[1][0] == token.RIGHTSHIFT: + assert len(nodelist) == 3 \ + or nodelist[3][0] == token.COMMA + dest = self.com_node(nodelist[2]) + start = 4 + else: + dest = None + start = 1 + for i in range(start, len(nodelist), 2): + items.append(self.com_node(nodelist[i])) + if nodelist[-1][0] == token.COMMA: + n = Print(items, dest) + n.lineno = nodelist[0][2] + return n + n = Printnl(items, dest) + n.lineno = nodelist[0][2] + return n + + def del_stmt(self, nodelist): + return self.com_assign(nodelist[1], OP_DELETE) + + def pass_stmt(self, nodelist): + n = Pass() + n.lineno = nodelist[0][2] + return n + + def break_stmt(self, nodelist): + n = Break() + n.lineno = nodelist[0][2] + return n + + def continue_stmt(self, nodelist): + n = Continue() + n.lineno = 
nodelist[0][2] + return n + + def return_stmt(self, nodelist): + # return: [testlist] + if len(nodelist) < 2: + n = Return(Const(None)) + n.lineno = nodelist[0][2] + return n + n = Return(self.com_node(nodelist[1])) + n.lineno = nodelist[0][2] + return n + + def yield_stmt(self, nodelist): + n = Yield(self.com_node(nodelist[1])) + n.lineno = nodelist[0][2] + return n + + def raise_stmt(self, nodelist): + # raise: [test [',' test [',' test]]] + if len(nodelist) > 5: + expr3 = self.com_node(nodelist[5]) + else: + expr3 = None + if len(nodelist) > 3: + expr2 = self.com_node(nodelist[3]) + else: + expr2 = None + if len(nodelist) > 1: + expr1 = self.com_node(nodelist[1]) + else: + expr1 = None + n = Raise(expr1, expr2, expr3) + n.lineno = nodelist[0][2] + return n + + def import_stmt(self, nodelist): + # import_stmt: 'import' dotted_as_name (',' dotted_as_name)* | + # from: 'from' dotted_name 'import' + # ('*' | import_as_name (',' import_as_name)*) + if nodelist[0][1] == 'from': + names = [] + if nodelist[3][0] == token.NAME: + for i in range(3, len(nodelist), 2): + names.append((nodelist[i][1], None)) + else: + for i in range(3, len(nodelist), 2): + names.append(self.com_import_as_name(nodelist[i])) + n = From(self.com_dotted_name(nodelist[1]), names) + n.lineno = nodelist[0][2] + return n + + if nodelist[1][0] == symbol.dotted_name: + names = [(self.com_dotted_name(nodelist[1][1:]), None)] + else: + names = [] + for i in range(1, len(nodelist), 2): + names.append(self.com_dotted_as_name(nodelist[i])) + n = Import(names) + n.lineno = nodelist[0][2] + return n + + def global_stmt(self, nodelist): + # global: NAME (',' NAME)* + names = [] + for i in range(1, len(nodelist), 2): + names.append(nodelist[i][1]) + n = Global(names) + n.lineno = nodelist[0][2] + return n + + def exec_stmt(self, nodelist): + # exec_stmt: 'exec' expr ['in' expr [',' expr]] + expr1 = self.com_node(nodelist[1]) + if len(nodelist) >= 4: + expr2 = self.com_node(nodelist[3]) + if len(nodelist) >= 6: 
+ expr3 = self.com_node(nodelist[5]) + else: + expr3 = None + else: + expr2 = expr3 = None + + n = Exec(expr1, expr2, expr3) + n.lineno = nodelist[0][2] + return n + + def assert_stmt(self, nodelist): + # 'assert': test, [',' test] + expr1 = self.com_node(nodelist[1]) + if (len(nodelist) == 4): + expr2 = self.com_node(nodelist[3]) + else: + expr2 = None + n = Assert(expr1, expr2) + n.lineno = nodelist[0][2] + return n + + def if_stmt(self, nodelist): + # if: test ':' suite ('elif' test ':' suite)* ['else' ':' suite] + tests = [] + for i in range(0, len(nodelist) - 3, 4): + testNode = self.com_node(nodelist[i + 1]) + suiteNode = self.com_node(nodelist[i + 3]) + tests.append((testNode, suiteNode)) + + if len(nodelist) % 4 == 3: + elseNode = self.com_node(nodelist[-1]) +## elseNode.lineno = nodelist[-1][1][2] + else: + elseNode = None + n = If(tests, elseNode) + n.lineno = nodelist[0][2] + return n + + def while_stmt(self, nodelist): + # 'while' test ':' suite ['else' ':' suite] + + testNode = self.com_node(nodelist[1]) + bodyNode = self.com_node(nodelist[3]) + + if len(nodelist) > 4: + elseNode = self.com_node(nodelist[6]) + else: + elseNode = None + + n = While(testNode, bodyNode, elseNode) + n.lineno = nodelist[0][2] + return n + + def for_stmt(self, nodelist): + # 'for' exprlist 'in' exprlist ':' suite ['else' ':' suite] + + assignNode = self.com_assign(nodelist[1], OP_ASSIGN) + listNode = self.com_node(nodelist[3]) + bodyNode = self.com_node(nodelist[5]) + + if len(nodelist) > 8: + elseNode = self.com_node(nodelist[8]) + else: + elseNode = None + + n = For(assignNode, listNode, bodyNode, elseNode) + n.lineno = nodelist[0][2] + return n + + def try_stmt(self, nodelist): + # 'try' ':' suite (except_clause ':' suite)+ ['else' ':' suite] + # | 'try' ':' s... [truncated message content] |
From: <pj...@us...> - 2008-02-18 19:01:59
|
Revision: 4149 http://jython.svn.sourceforge.net/jython/?rev=4149&view=rev Author: pjenvey Date: 2008-02-18 11:01:57 -0800 (Mon, 18 Feb 2008) Log Message: ----------- subprocess from http://svn.python.org/projects/python/trunk@60896 Added Paths: ----------- trunk/jython/Lib/subprocess.py trunk/jython/Lib/test/test_subprocess.py Added: trunk/jython/Lib/subprocess.py =================================================================== --- trunk/jython/Lib/subprocess.py (rev 0) +++ trunk/jython/Lib/subprocess.py 2008-02-18 19:01:57 UTC (rev 4149) @@ -0,0 +1,1250 @@ +# subprocess - Subprocesses with accessible I/O streams +# +# For more information about this module, see PEP 324. +# +# This module should remain compatible with Python 2.2, see PEP 291. +# +# Copyright (c) 2003-2005 by Peter Astrand <as...@ly...> +# +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/2.4/license for licensing details. + +r"""subprocess - Subprocesses with accessible I/O streams + +This module allows you to spawn processes, connect to their +input/output/error pipes, and obtain their return codes. This module +intends to replace several other, older modules and functions, like: + +os.system +os.spawn* +os.popen* +popen2.* +commands.* + +Information about how the subprocess module can be used to replace these +modules and functions can be found below. + + + +Using the subprocess module +=========================== +This module defines one class called Popen: + +class Popen(args, bufsize=0, executable=None, + stdin=None, stdout=None, stderr=None, + preexec_fn=None, close_fds=False, shell=False, + cwd=None, env=None, universal_newlines=False, + startupinfo=None, creationflags=0): + + +Arguments are: + +args should be a string, or a sequence of program arguments. The +program to execute is normally the first item in the args sequence or +string, but can be explicitly set by using the executable argument. 
+ +On UNIX, with shell=False (default): In this case, the Popen class +uses os.execvp() to execute the child program. args should normally +be a sequence. A string will be treated as a sequence with the string +as the only item (the program to execute). + +On UNIX, with shell=True: If args is a string, it specifies the +command string to execute through the shell. If args is a sequence, +the first item specifies the command string, and any additional items +will be treated as additional shell arguments. + +On Windows: the Popen class uses CreateProcess() to execute the child +program, which operates on strings. If args is a sequence, it will be +converted to a string using the list2cmdline method. Please note that +not all MS Windows applications interpret the command line the same +way: The list2cmdline is designed for applications using the same +rules as the MS C runtime. + +bufsize, if given, has the same meaning as the corresponding argument +to the built-in open() function: 0 means unbuffered, 1 means line +buffered, any other positive value means use a buffer of +(approximately) that size. A negative bufsize means to use the system +default, which usually means fully buffered. The default value for +bufsize is 0 (unbuffered). + +stdin, stdout and stderr specify the executed programs' standard +input, standard output and standard error file handles, respectively. +Valid values are PIPE, an existing file descriptor (a positive +integer), an existing file object, and None. PIPE indicates that a +new pipe to the child should be created. With None, no redirection +will occur; the child's file handles will be inherited from the +parent. Additionally, stderr can be STDOUT, which indicates that the +stderr data from the applications should be captured into the same +file handle as for stdout. + +If preexec_fn is set to a callable object, this object will be called +in the child process just before the child is executed. 
+ +If close_fds is true, all file descriptors except 0, 1 and 2 will be +closed before the child process is executed. + +if shell is true, the specified command will be executed through the +shell. + +If cwd is not None, the current directory will be changed to cwd +before the child is executed. + +If env is not None, it defines the environment variables for the new +process. + +If universal_newlines is true, the file objects stdout and stderr are +opened as a text files, but lines may be terminated by any of '\n', +the Unix end-of-line convention, '\r', the Macintosh convention or +'\r\n', the Windows convention. All of these external representations +are seen as '\n' by the Python program. Note: This feature is only +available if Python is built with universal newline support (the +default). Also, the newlines attribute of the file objects stdout, +stdin and stderr are not updated by the communicate() method. + +The startupinfo and creationflags, if given, will be passed to the +underlying CreateProcess() function. They can specify things such as +appearance of the main window and priority for the new process. +(Windows only) + + +This module also defines two shortcut functions: + +call(*popenargs, **kwargs): + Run command with arguments. Wait for command to complete, then + return the returncode attribute. + + The arguments are the same as for the Popen constructor. Example: + + retcode = call(["ls", "-l"]) + +check_call(*popenargs, **kwargs): + Run command with arguments. Wait for command to complete. If the + exit code was zero then return, otherwise raise + CalledProcessError. The CalledProcessError object will have the + return code in the returncode attribute. + + The arguments are the same as for the Popen constructor. Example: + + check_call(["ls", "-l"]) + +Exceptions +---------- +Exceptions raised in the child process, before the new program has +started to execute, will be re-raised in the parent. 
Additionally, +the exception object will have one extra attribute called +'child_traceback', which is a string containing traceback information +from the childs point of view. + +The most common exception raised is OSError. This occurs, for +example, when trying to execute a non-existent file. Applications +should prepare for OSErrors. + +A ValueError will be raised if Popen is called with invalid arguments. + +check_call() will raise CalledProcessError, if the called process +returns a non-zero return code. + + +Security +-------- +Unlike some other popen functions, this implementation will never call +/bin/sh implicitly. This means that all characters, including shell +metacharacters, can safely be passed to child processes. + + +Popen objects +============= +Instances of the Popen class have the following methods: + +poll() + Check if child process has terminated. Returns returncode + attribute. + +wait() + Wait for child process to terminate. Returns returncode attribute. + +communicate(input=None) + Interact with process: Send data to stdin. Read data from stdout + and stderr, until end-of-file is reached. Wait for process to + terminate. The optional input argument should be a string to be + sent to the child process, or None, if no data should be sent to + the child. + + communicate() returns a tuple (stdout, stderr). + + Note: The data read is buffered in memory, so do not use this + method if the data size is large or unlimited. + +The following attributes are also available: + +stdin + If the stdin argument is PIPE, this attribute is a file object + that provides input to the child process. Otherwise, it is None. + +stdout + If the stdout argument is PIPE, this attribute is a file object + that provides output from the child process. Otherwise, it is + None. + +stderr + If the stderr argument is PIPE, this attribute is file object that + provides error output from the child process. Otherwise, it is + None. + +pid + The process ID of the child process. 
+ +returncode + The child return code. A None value indicates that the process + hasn't terminated yet. A negative value -N indicates that the + child was terminated by signal N (UNIX only). + + +Replacing older functions with the subprocess module +==================================================== +In this section, "a ==> b" means that b can be used as a replacement +for a. + +Note: All functions in this section fail (more or less) silently if +the executed program cannot be found; this module raises an OSError +exception. + +In the following examples, we assume that the subprocess module is +imported with "from subprocess import *". + + +Replacing /bin/sh shell backquote +--------------------------------- +output=`mycmd myarg` +==> +output = Popen(["mycmd", "myarg"], stdout=PIPE).communicate()[0] + + +Replacing shell pipe line +------------------------- +output=`dmesg | grep hda` +==> +p1 = Popen(["dmesg"], stdout=PIPE) +p2 = Popen(["grep", "hda"], stdin=p1.stdout, stdout=PIPE) +output = p2.communicate()[0] + + +Replacing os.system() +--------------------- +sts = os.system("mycmd" + " myarg") +==> +p = Popen("mycmd" + " myarg", shell=True) +pid, sts = os.waitpid(p.pid, 0) + +Note: + +* Calling the program through the shell is usually not required. + +* It's easier to look at the returncode attribute than the + exitstatus. 
+ +A more real-world example would look like this: + +try: + retcode = call("mycmd" + " myarg", shell=True) + if retcode < 0: + print >>sys.stderr, "Child was terminated by signal", -retcode + else: + print >>sys.stderr, "Child returned", retcode +except OSError, e: + print >>sys.stderr, "Execution failed:", e + + +Replacing os.spawn* +------------------- +P_NOWAIT example: + +pid = os.spawnlp(os.P_NOWAIT, "/bin/mycmd", "mycmd", "myarg") +==> +pid = Popen(["/bin/mycmd", "myarg"]).pid + + +P_WAIT example: + +retcode = os.spawnlp(os.P_WAIT, "/bin/mycmd", "mycmd", "myarg") +==> +retcode = call(["/bin/mycmd", "myarg"]) + + +Vector example: + +os.spawnvp(os.P_NOWAIT, path, args) +==> +Popen([path] + args[1:]) + + +Environment example: + +os.spawnlpe(os.P_NOWAIT, "/bin/mycmd", "mycmd", "myarg", env) +==> +Popen(["/bin/mycmd", "myarg"], env={"PATH": "/usr/bin"}) + + +Replacing os.popen* +------------------- +pipe = os.popen(cmd, mode='r', bufsize) +==> +pipe = Popen(cmd, shell=True, bufsize=bufsize, stdout=PIPE).stdout + +pipe = os.popen(cmd, mode='w', bufsize) +==> +pipe = Popen(cmd, shell=True, bufsize=bufsize, stdin=PIPE).stdin + + +(child_stdin, child_stdout) = os.popen2(cmd, mode, bufsize) +==> +p = Popen(cmd, shell=True, bufsize=bufsize, + stdin=PIPE, stdout=PIPE, close_fds=True) +(child_stdin, child_stdout) = (p.stdin, p.stdout) + + +(child_stdin, + child_stdout, + child_stderr) = os.popen3(cmd, mode, bufsize) +==> +p = Popen(cmd, shell=True, bufsize=bufsize, + stdin=PIPE, stdout=PIPE, stderr=PIPE, close_fds=True) +(child_stdin, + child_stdout, + child_stderr) = (p.stdin, p.stdout, p.stderr) + + +(child_stdin, child_stdout_and_stderr) = os.popen4(cmd, mode, bufsize) +==> +p = Popen(cmd, shell=True, bufsize=bufsize, + stdin=PIPE, stdout=PIPE, stderr=STDOUT, close_fds=True) +(child_stdin, child_stdout_and_stderr) = (p.stdin, p.stdout) + + +Replacing popen2.* +------------------ +Note: If the cmd argument to popen2 functions is a string, the command +is executed 
through /bin/sh. If it is a list, the command is directly +executed. + +(child_stdout, child_stdin) = popen2.popen2("somestring", bufsize, mode) +==> +p = Popen(["somestring"], shell=True, bufsize=bufsize + stdin=PIPE, stdout=PIPE, close_fds=True) +(child_stdout, child_stdin) = (p.stdout, p.stdin) + + +(child_stdout, child_stdin) = popen2.popen2(["mycmd", "myarg"], bufsize, mode) +==> +p = Popen(["mycmd", "myarg"], bufsize=bufsize, + stdin=PIPE, stdout=PIPE, close_fds=True) +(child_stdout, child_stdin) = (p.stdout, p.stdin) + +The popen2.Popen3 and popen2.Popen4 basically works as subprocess.Popen, +except that: + +* subprocess.Popen raises an exception if the execution fails +* the capturestderr argument is replaced with the stderr argument. +* stdin=PIPE and stdout=PIPE must be specified. +* popen2 closes all filedescriptors by default, but you have to specify + close_fds=True with subprocess.Popen. +""" + +import sys +mswindows = (sys.platform == "win32") + +import os +import types +import traceback +import gc + +# Exception classes used by this module. +class CalledProcessError(Exception): + """This exception is raised when a process run by check_call() returns + a non-zero exit status. 
The exit status will be stored in the + returncode attribute.""" + def __init__(self, returncode, cmd): + self.returncode = returncode + self.cmd = cmd + def __str__(self): + return "Command '%s' returned non-zero exit status %d" % (self.cmd, self.returncode) + + +if mswindows: + import threading + import msvcrt + if 0: # <-- change this to use pywin32 instead of the _subprocess driver + import pywintypes + from win32api import GetStdHandle, STD_INPUT_HANDLE, \ + STD_OUTPUT_HANDLE, STD_ERROR_HANDLE + from win32api import GetCurrentProcess, DuplicateHandle, \ + GetModuleFileName, GetVersion + from win32con import DUPLICATE_SAME_ACCESS, SW_HIDE + from win32pipe import CreatePipe + from win32process import CreateProcess, STARTUPINFO, \ + GetExitCodeProcess, STARTF_USESTDHANDLES, \ + STARTF_USESHOWWINDOW, CREATE_NEW_CONSOLE + from win32event import WaitForSingleObject, INFINITE, WAIT_OBJECT_0 + else: + from _subprocess import * + class STARTUPINFO: + dwFlags = 0 + hStdInput = None + hStdOutput = None + hStdError = None + wShowWindow = 0 + class pywintypes: + error = IOError +else: + import select + import errno + import fcntl + import pickle + +__all__ = ["Popen", "PIPE", "STDOUT", "call", "check_call", "CalledProcessError"] + +try: + MAXFD = os.sysconf("SC_OPEN_MAX") +except: + MAXFD = 256 + +# True/False does not exist on 2.2.0 +try: + False +except NameError: + False = 0 + True = 1 + +_active = [] + +def _cleanup(): + for inst in _active[:]: + if inst.poll(_deadstate=sys.maxint) >= 0: + try: + _active.remove(inst) + except ValueError: + # This can happen if two threads create a new Popen instance. + # It's harmless that it was already removed, so ignore. + pass + +PIPE = -1 +STDOUT = -2 + + +def call(*popenargs, **kwargs): + """Run command with arguments. Wait for command to complete, then + return the returncode attribute. + + The arguments are the same as for the Popen constructor. 
Example: + + retcode = call(["ls", "-l"]) + """ + return Popen(*popenargs, **kwargs).wait() + + +def check_call(*popenargs, **kwargs): + """Run command with arguments. Wait for command to complete. If + the exit code was zero then return, otherwise raise + CalledProcessError. The CalledProcessError object will have the + return code in the returncode attribute. + + The arguments are the same as for the Popen constructor. Example: + + check_call(["ls", "-l"]) + """ + retcode = call(*popenargs, **kwargs) + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + if retcode: + raise CalledProcessError(retcode, cmd) + return retcode + + +def list2cmdline(seq): + """ + Translate a sequence of arguments into a command line + string, using the same rules as the MS C runtime: + + 1) Arguments are delimited by white space, which is either a + space or a tab. + + 2) A string surrounded by double quotation marks is + interpreted as a single argument, regardless of white space + or pipe characters contained within. A quoted string can be + embedded in an argument. + + 3) A double quotation mark preceded by a backslash is + interpreted as a literal double quotation mark. + + 4) Backslashes are interpreted literally, unless they + immediately precede a double quotation mark. + + 5) If backslashes immediately precede a double quotation mark, + every pair of backslashes is interpreted as a literal + backslash. If the number of backslashes is odd, the last + backslash escapes the next double quotation mark as + described in rule 3. + """ + + # See + # http://msdn.microsoft.com/library/en-us/vccelng/htm/progs_12.asp + result = [] + needquote = False + for arg in seq: + bs_buf = [] + + # Add a space to separate this argument from the others + if result: + result.append(' ') + + needquote = (" " in arg) or ("\t" in arg) or ("|" in arg) or not arg + if needquote: + result.append('"') + + for c in arg: + if c == '\\': + # Don't know if we need to double yet. 
+ bs_buf.append(c) + elif c == '"': + # Double backslashes. + result.append('\\' * len(bs_buf)*2) + bs_buf = [] + result.append('\\"') + else: + # Normal char + if bs_buf: + result.extend(bs_buf) + bs_buf = [] + result.append(c) + + # Add remaining backslashes, if any. + if bs_buf: + result.extend(bs_buf) + + if needquote: + result.extend(bs_buf) + result.append('"') + + return ''.join(result) + + +class Popen(object): + def __init__(self, args, bufsize=0, executable=None, + stdin=None, stdout=None, stderr=None, + preexec_fn=None, close_fds=False, shell=False, + cwd=None, env=None, universal_newlines=False, + startupinfo=None, creationflags=0): + """Create new Popen instance.""" + _cleanup() + + self._child_created = False + if not isinstance(bufsize, (int, long)): + raise TypeError("bufsize must be an integer") + + if mswindows: + if preexec_fn is not None: + raise ValueError("preexec_fn is not supported on Windows " + "platforms") + if close_fds and (stdin is not None or stdout is not None or + stderr is not None): + raise ValueError("close_fds is not supported on Windows " + "platforms if you redirect stdin/stdout/stderr") + else: + # POSIX + if startupinfo is not None: + raise ValueError("startupinfo is only supported on Windows " + "platforms") + if creationflags != 0: + raise ValueError("creationflags is only supported on Windows " + "platforms") + + self.stdin = None + self.stdout = None + self.stderr = None + self.pid = None + self.returncode = None + self.universal_newlines = universal_newlines + + # Input and output objects. The general principle is like + # this: + # + # Parent Child + # ------ ----- + # p2cwrite ---stdin---> p2cread + # c2pread <--stdout--- c2pwrite + # errread <--stderr--- errwrite + # + # On POSIX, the child objects are file descriptors. On + # Windows, these are Windows file handles. The parent objects + # are file descriptors on both platforms. The parent objects + # are None when not using PIPEs. 
The child objects are None + # when not redirecting. + + (p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite) = self._get_handles(stdin, stdout, stderr) + + self._execute_child(args, executable, preexec_fn, close_fds, + cwd, env, universal_newlines, + startupinfo, creationflags, shell, + p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite) + + # On Windows, you cannot just redirect one or two handles: You + # either have to redirect all three or none. If the subprocess + # user has only redirected one or two handles, we are + # automatically creating PIPEs for the rest. We should close + # these after the process is started. See bug #1124861. + if mswindows: + if stdin is None and p2cwrite is not None: + os.close(p2cwrite) + p2cwrite = None + if stdout is None and c2pread is not None: + os.close(c2pread) + c2pread = None + if stderr is None and errread is not None: + os.close(errread) + errread = None + + if p2cwrite is not None: + self.stdin = os.fdopen(p2cwrite, 'wb', bufsize) + if c2pread is not None: + if universal_newlines: + self.stdout = os.fdopen(c2pread, 'rU', bufsize) + else: + self.stdout = os.fdopen(c2pread, 'rb', bufsize) + if errread is not None: + if universal_newlines: + self.stderr = os.fdopen(errread, 'rU', bufsize) + else: + self.stderr = os.fdopen(errread, 'rb', bufsize) + + + def _translate_newlines(self, data): + data = data.replace("\r\n", "\n") + data = data.replace("\r", "\n") + return data + + + def __del__(self, sys=sys): + if not self._child_created: + # We didn't get to successfully create a child process. + return + # In case the child hasn't been waited on, check if it's done. + self.poll(_deadstate=sys.maxint) + if self.returncode is None and _active is not None: + # Child is still running, keep us alive until we can wait on it. + _active.append(self) + + + def communicate(self, input=None): + """Interact with process: Send data to stdin. Read data from + stdout and stderr, until end-of-file is reached. 
Wait for + process to terminate. The optional input argument should be a + string to be sent to the child process, or None, if no data + should be sent to the child. + + communicate() returns a tuple (stdout, stderr).""" + + # Optimization: If we are only using one pipe, or no pipe at + # all, using select() or threads is unnecessary. + if [self.stdin, self.stdout, self.stderr].count(None) >= 2: + stdout = None + stderr = None + if self.stdin: + if input: + self.stdin.write(input) + self.stdin.close() + elif self.stdout: + stdout = self.stdout.read() + elif self.stderr: + stderr = self.stderr.read() + self.wait() + return (stdout, stderr) + + return self._communicate(input) + + + if mswindows: + # + # Windows methods + # + def _get_handles(self, stdin, stdout, stderr): + """Construct and return tupel with IO objects: + p2cread, p2cwrite, c2pread, c2pwrite, errread, errwrite + """ + if stdin is None and stdout is None and stderr is None: + return (None, None, None, None, None, None) + + p2cread, p2cwrite = None, None + c2pread, c2pwrite = None, None + errread, errwrite = None, None + + if stdin is None: + p2cread = GetStdHandle(STD_INPUT_HANDLE) + if p2cread is not None: + pass + elif stdin is None or stdin == PIPE: + p2cread, p2cwrite = CreatePipe(None, 0) + # Detach and turn into fd + p2cwrite = p2cwrite.Detach() + p2cwrite = msvcrt.open_osfhandle(p2cwrite, 0) + elif isinstance(stdin, int): + p2cread = msvcrt.get_osfhandle(stdin) + else: + # Assuming file-like object + p2cread = msvcrt.get_osfhandle(stdin.fileno()) + p2cread = self._make_inheritable(p2cread) + + if stdout is None: + c2pwrite = GetStdHandle(STD_OUTPUT_HANDLE) + if c2pwrite is not None: + pass + elif stdout is None or stdout == PIPE: + c2pread, c2pwrite = CreatePipe(None, 0) + # Detach and turn into fd + c2pread = c2pread.Detach() + c2pread = msvcrt.open_osfhandle(c2pread, 0) + elif isinstance(stdout, int): + c2pwrite = msvcrt.get_osfhandle(stdout) + else: + # Assuming file-like object + c2pwrite = 
msvcrt.get_osfhandle(stdout.fileno()) + c2pwrite = self._make_inheritable(c2pwrite) + + if stderr is None: + errwrite = GetStdHandle(STD_ERROR_HANDLE) + if errwrite is not None: + pass + elif stderr is None or stderr == PIPE: + errread, errwrite = CreatePipe(None, 0) + # Detach and turn into fd + errread = errread.Detach() + errread = msvcrt.open_osfhandle(errread, 0) + elif stderr == STDOUT: + errwrite = c2pwrite + elif isinstance(stderr, int): + errwrite = msvcrt.get_osfhandle(stderr) + else: + # Assuming file-like object + errwrite = msvcrt.get_osfhandle(stderr.fileno()) + errwrite = self._make_inheritable(errwrite) + + return (p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite) + + + def _make_inheritable(self, handle): + """Return a duplicate of handle, which is inheritable""" + return DuplicateHandle(GetCurrentProcess(), handle, + GetCurrentProcess(), 0, 1, + DUPLICATE_SAME_ACCESS) + + + def _find_w9xpopen(self): + """Find and return absolut path to w9xpopen.exe""" + w9xpopen = os.path.join(os.path.dirname(GetModuleFileName(0)), + "w9xpopen.exe") + if not os.path.exists(w9xpopen): + # Eeek - file-not-found - possibly an embedding + # situation - see if we can locate it in sys.exec_prefix + w9xpopen = os.path.join(os.path.dirname(sys.exec_prefix), + "w9xpopen.exe") + if not os.path.exists(w9xpopen): + raise RuntimeError("Cannot locate w9xpopen.exe, which is " + "needed for Popen to work with your " + "shell or platform.") + return w9xpopen + + + def _execute_child(self, args, executable, preexec_fn, close_fds, + cwd, env, universal_newlines, + startupinfo, creationflags, shell, + p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite): + """Execute program (MS Windows version)""" + + if not isinstance(args, types.StringTypes): + args = list2cmdline(args) + + # Process startup details + if startupinfo is None: + startupinfo = STARTUPINFO() + if None not in (p2cread, c2pwrite, errwrite): + startupinfo.dwFlags |= STARTF_USESTDHANDLES + 
startupinfo.hStdInput = p2cread + startupinfo.hStdOutput = c2pwrite + startupinfo.hStdError = errwrite + + if shell: + startupinfo.dwFlags |= STARTF_USESHOWWINDOW + startupinfo.wShowWindow = SW_HIDE + comspec = os.environ.get("COMSPEC", "cmd.exe") + args = comspec + " /c " + args + if (GetVersion() >= 0x80000000L or + os.path.basename(comspec).lower() == "command.com"): + # Win9x, or using command.com on NT. We need to + # use the w9xpopen intermediate program. For more + # information, see KB Q150956 + # (http://web.archive.org/web/20011105084002/http://support.microsoft.com/support/kb/articles/Q150/9/56.asp) + w9xpopen = self._find_w9xpopen() + args = '"%s" %s' % (w9xpopen, args) + # Not passing CREATE_NEW_CONSOLE has been known to + # cause random failures on win9x. Specifically a + # dialog: "Your program accessed mem currently in + # use at xxx" and a hopeful warning about the + # stability of your system. Cost is Ctrl+C wont + # kill children. + creationflags |= CREATE_NEW_CONSOLE + + # Start the process + try: + hp, ht, pid, tid = CreateProcess(executable, args, + # no special security + None, None, + int(not close_fds), + creationflags, + env, + cwd, + startupinfo) + except pywintypes.error, e: + # Translate pywintypes.error to WindowsError, which is + # a subclass of OSError. FIXME: We should really + # translate errno using _sys_errlist (or simliar), but + # how can this be done from Python? + raise WindowsError(*e.args) + + # Retain the process handle, but close the thread handle + self._child_created = True + self._handle = hp + self.pid = pid + ht.Close() + + # Child is launched. Close the parent's copy of those pipe + # handles that only the child should have open. You need + # to make sure that no handles to the write end of the + # output pipe are maintained in this process or else the + # pipe will not close when the child process exits and the + # ReadFile will hang. 
+ if p2cread is not None: + p2cread.Close() + if c2pwrite is not None: + c2pwrite.Close() + if errwrite is not None: + errwrite.Close() + + + def poll(self, _deadstate=None): + """Check if child process has terminated. Returns returncode + attribute.""" + if self.returncode is None: + if WaitForSingleObject(self._handle, 0) == WAIT_OBJECT_0: + self.returncode = GetExitCodeProcess(self._handle) + return self.returncode + + + def wait(self): + """Wait for child process to terminate. Returns returncode + attribute.""" + if self.returncode is None: + obj = WaitForSingleObject(self._handle, INFINITE) + self.returncode = GetExitCodeProcess(self._handle) + return self.returncode + + + def _readerthread(self, fh, buffer): + buffer.append(fh.read()) + + + def _communicate(self, input): + stdout = None # Return + stderr = None # Return + + if self.stdout: + stdout = [] + stdout_thread = threading.Thread(target=self._readerthread, + args=(self.stdout, stdout)) + stdout_thread.setDaemon(True) + stdout_thread.start() + if self.stderr: + stderr = [] + stderr_thread = threading.Thread(target=self._readerthread, + args=(self.stderr, stderr)) + stderr_thread.setDaemon(True) + stderr_thread.start() + + if self.stdin: + if input is not None: + self.stdin.write(input) + self.stdin.close() + + if self.stdout: + stdout_thread.join() + if self.stderr: + stderr_thread.join() + + # All data exchanged. Translate lists into strings. + if stdout is not None: + stdout = stdout[0] + if stderr is not None: + stderr = stderr[0] + + # Translate newlines, if requested. We cannot let the file + # object do the translation: It is based on stdio, which is + # impossible to combine with select (unless forcing no + # buffering). 
+ if self.universal_newlines and hasattr(file, 'newlines'): + if stdout: + stdout = self._translate_newlines(stdout) + if stderr: + stderr = self._translate_newlines(stderr) + + self.wait() + return (stdout, stderr) + + else: + # + # POSIX methods + # + def _get_handles(self, stdin, stdout, stderr): + """Construct and return tupel with IO objects: + p2cread, p2cwrite, c2pread, c2pwrite, errread, errwrite + """ + p2cread, p2cwrite = None, None + c2pread, c2pwrite = None, None + errread, errwrite = None, None + + if stdin is None: + pass + elif stdin == PIPE: + p2cread, p2cwrite = os.pipe() + elif isinstance(stdin, int): + p2cread = stdin + else: + # Assuming file-like object + p2cread = stdin.fileno() + + if stdout is None: + pass + elif stdout == PIPE: + c2pread, c2pwrite = os.pipe() + elif isinstance(stdout, int): + c2pwrite = stdout + else: + # Assuming file-like object + c2pwrite = stdout.fileno() + + if stderr is None: + pass + elif stderr == PIPE: + errread, errwrite = os.pipe() + elif stderr == STDOUT: + errwrite = c2pwrite + elif isinstance(stderr, int): + errwrite = stderr + else: + # Assuming file-like object + errwrite = stderr.fileno() + + return (p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite) + + + def _set_cloexec_flag(self, fd): + try: + cloexec_flag = fcntl.FD_CLOEXEC + except AttributeError: + cloexec_flag = 1 + + old = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, old | cloexec_flag) + + + def _close_fds(self, but): + os.closerange(3, but) + os.closerange(but + 1, MAXFD) + + + def _execute_child(self, args, executable, preexec_fn, close_fds, + cwd, env, universal_newlines, + startupinfo, creationflags, shell, + p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite): + """Execute program (POSIX version)""" + + if isinstance(args, types.StringTypes): + args = [args] + else: + args = list(args) + + if shell: + args = ["/bin/sh", "-c"] + args + + if executable is None: + executable = args[0] + + # For transferring 
possible exec failure from child to parent + # The first char specifies the exception type: 0 means + # OSError, 1 means some other error. + errpipe_read, errpipe_write = os.pipe() + self._set_cloexec_flag(errpipe_write) + + gc_was_enabled = gc.isenabled() + # Disable gc to avoid bug where gc -> file_dealloc -> + # write to stderr -> hang. http://bugs.python.org/issue1336 + gc.disable() + try: + self.pid = os.fork() + except: + if gc_was_enabled: + gc.enable() + raise + self._child_created = True + if self.pid == 0: + # Child + try: + # Close parent's pipe ends + if p2cwrite is not None: + os.close(p2cwrite) + if c2pread is not None: + os.close(c2pread) + if errread is not None: + os.close(errread) + os.close(errpipe_read) + + # Dup fds for child + if p2cread is not None: + os.dup2(p2cread, 0) + if c2pwrite is not None: + os.dup2(c2pwrite, 1) + if errwrite is not None: + os.dup2(errwrite, 2) + + # Close pipe fds. Make sure we don't close the same + # fd more than once, or standard fds. + if p2cread is not None and p2cread not in (0,): + os.close(p2cread) + if c2pwrite is not None and c2pwrite not in (p2cread, 1): + os.close(c2pwrite) + if errwrite is not None and errwrite not in (p2cread, c2pwrite, 2): + os.close(errwrite) + + # Close all other fds, if asked for + if close_fds: + self._close_fds(but=errpipe_write) + + if cwd is not None: + os.chdir(cwd) + + if preexec_fn: + apply(preexec_fn) + + if env is None: + os.execvp(executable, args) + else: + os.execvpe(executable, args, env) + + except: + exc_type, exc_value, tb = sys.exc_info() + # Save the traceback and attach it to the exception object + exc_lines = traceback.format_exception(exc_type, + exc_value, + tb) + exc_value.child_traceback = ''.join(exc_lines) + os.write(errpipe_write, pickle.dumps(exc_value)) + + # This exitcode won't be reported to applications, so it + # really doesn't matter what we return. 
+ os._exit(255) + + # Parent + if gc_was_enabled: + gc.enable() + os.close(errpipe_write) + if p2cread is not None and p2cwrite is not None: + os.close(p2cread) + if c2pwrite is not None and c2pread is not None: + os.close(c2pwrite) + if errwrite is not None and errread is not None: + os.close(errwrite) + + # Wait for exec to fail or succeed; possibly raising exception + data = os.read(errpipe_read, 1048576) # Exceptions limited to 1 MB + os.close(errpipe_read) + if data != "": + os.waitpid(self.pid, 0) + child_exception = pickle.loads(data) + raise child_exception + + + def _handle_exitstatus(self, sts): + if os.WIFSIGNALED(sts): + self.returncode = -os.WTERMSIG(sts) + elif os.WIFEXITED(sts): + self.returncode = os.WEXITSTATUS(sts) + else: + # Should never happen + raise RuntimeError("Unknown child exit status!") + + + def poll(self, _deadstate=None): + """Check if child process has terminated. Returns returncode + attribute.""" + if self.returncode is None: + try: + pid, sts = os.waitpid(self.pid, os.WNOHANG) + if pid == self.pid: + self._handle_exitstatus(sts) + except os.error: + if _deadstate is not None: + self.returncode = _deadstate + return self.returncode + + + def wait(self): + """Wait for child process to terminate. Returns returncode + attribute.""" + if self.returncode is None: + pid, sts = os.waitpid(self.pid, 0) + self._handle_exitstatus(sts) + return self.returncode + + + def _communicate(self, input): + read_set = [] + write_set = [] + stdout = None # Return + stderr = None # Return + + if self.stdin: + # Flush stdio buffer. This might block, if the user has + # been writing to .stdin in an uncontrolled fashion. 
+ self.stdin.flush() + if input: + write_set.append(self.stdin) + else: + self.stdin.close() + if self.stdout: + read_set.append(self.stdout) + stdout = [] + if self.stderr: + read_set.append(self.stderr) + stderr = [] + + input_offset = 0 + while read_set or write_set: + rlist, wlist, xlist = select.select(read_set, write_set, []) + + if self.stdin in wlist: + # When select has indicated that the file is writable, + # we can write up to PIPE_BUF bytes without risk + # blocking. POSIX defines PIPE_BUF >= 512 + bytes_written = os.write(self.stdin.fileno(), buffer(input, input_offset, 512)) + input_offset += bytes_written + if input_offset >= len(input): + self.stdin.close() + write_set.remove(self.stdin) + + if self.stdout in rlist: + data = os.read(self.stdout.fileno(), 1024) + if data == "": + self.stdout.close() + read_set.remove(self.stdout) + stdout.append(data) + + if self.stderr in rlist: + data = os.read(self.stderr.fileno(), 1024) + if data == "": + self.stderr.close() + read_set.remove(self.stderr) + stderr.append(data) + + # All data exchanged. Translate lists into strings. + if stdout is not None: + stdout = ''.join(stdout) + if stderr is not None: + stderr = ''.join(stderr) + + # Translate newlines, if requested. We cannot let the file + # object do the translation: It is based on stdio, which is + # impossible to combine with select (unless forcing no + # buffering). 
+ if self.universal_newlines and hasattr(file, 'newlines'): + if stdout: + stdout = self._translate_newlines(stdout) + if stderr: + stderr = self._translate_newlines(stderr) + + self.wait() + return (stdout, stderr) + + +def _demo_posix(): + # + # Example 1: Simple redirection: Get process list + # + plist = Popen(["ps"], stdout=PIPE).communicate()[0] + print "Process list:" + print plist + + # + # Example 2: Change uid before executing child + # + if os.getuid() == 0: + p = Popen(["id"], preexec_fn=lambda: os.setuid(100)) + p.wait() + + # + # Example 3: Connecting several subprocesses + # + print "Looking for 'hda'..." + p1 = Popen(["dmesg"], stdout=PIPE) + p2 = Popen(["grep", "hda"], stdin=p1.stdout, stdout=PIPE) + print repr(p2.communicate()[0]) + + # + # Example 4: Catch execution error + # + print + print "Trying a weird file..." + try: + print Popen(["/this/path/does/not/exist"]).communicate() + except OSError, e: + if e.errno == errno.ENOENT: + print "The file didn't exist. I thought so..." + print "Child traceback:" + print e.child_traceback + else: + print "Error", e.errno + else: + print >>sys.stderr, "Gosh. No error." + + +def _demo_windows(): + # + # Example 1: Connecting several subprocesses + # + print "Looking for 'PROMPT' in set output..." + p1 = Popen("set", stdout=PIPE, shell=True) + p2 = Popen('find "PROMPT"', stdin=p1.stdout, stdout=PIPE) + print repr(p2.communicate()[0]) + + # + # Example 2: Simple execution of program + # + print "Executing calc..." 
+ p = Popen("calc") + p.wait() + + +if __name__ == "__main__": + if mswindows: + _demo_windows() + else: + _demo_posix() Added: trunk/jython/Lib/test/test_subprocess.py =================================================================== --- trunk/jython/Lib/test/test_subprocess.py (rev 0) +++ trunk/jython/Lib/test/test_subprocess.py 2008-02-18 19:01:57 UTC (rev 4149) @@ -0,0 +1,665 @@ +import unittest +from test import test_support +import subprocess +import sys +import signal +import os +import tempfile +import time +import re + +mswindows = (sys.platform == "win32") + +# +# Depends on the following external programs: Python +# + +if mswindows: + SETBINARY = ('import msvcrt; msvcrt.setmode(sys.stdout.fileno(), ' + 'os.O_BINARY);') +else: + SETBINARY = '' + +# In a debug build, stuff like "[6580 refs]" is printed to stderr at +# shutdown time. That frustrates tests trying to check stderr produced +# from a spawned Python process. +def remove_stderr_debug_decorations(stderr): + return re.sub(r"\[\d+ refs\]\r?\n?$", "", stderr) + +class ProcessTestCase(unittest.TestCase): + def setUp(self): + # Try to minimize the number of children we have so this test + # doesn't crash on some buildbots (Alphas in particular). + if hasattr(test_support, "reap_children"): + test_support.reap_children() + + def tearDown(self): + # Try to minimize the number of children we have so this test + # doesn't crash on some buildbots (Alphas in particular). 
+ if hasattr(test_support, "reap_children"): + test_support.reap_children() + + def mkstemp(self): + """wrapper for mkstemp, calling mktemp if mkstemp is not available""" + if hasattr(tempfile, "mkstemp"): + return tempfile.mkstemp() + else: + fname = tempfile.mktemp() + return os.open(fname, os.O_RDWR|os.O_CREAT), fname + + # + # Generic tests + # + def test_call_seq(self): + # call() function with sequence argument + rc = subprocess.call([sys.executable, "-c", + "import sys; sys.exit(47)"]) + self.assertEqual(rc, 47) + + def test_check_call_zero(self): + # check_call() function with zero return code + rc = subprocess.check_call([sys.executable, "-c", + "import sys; sys.exit(0)"]) + self.assertEqual(rc, 0) + + def test_check_call_nonzero(self): + # check_call() function with non-zero return code + try: + subprocess.check_call([sys.executable, "-c", + "import sys; sys.exit(47)"]) + except subprocess.CalledProcessError, e: + self.assertEqual(e.returncode, 47) + else: + self.fail("Expected CalledProcessError") + + def test_call_kwargs(self): + # call() function with keyword args + newenv = os.environ.copy() + newenv["FRUIT"] = "banana" + rc = subprocess.call([sys.executable, "-c", + 'import sys, os;' \ + 'sys.exit(os.getenv("FRUIT")=="banana")'], + env=newenv) + self.assertEqual(rc, 1) + + def test_stdin_none(self): + # .stdin is None when not redirected + p = subprocess.Popen([sys.executable, "-c", 'print "banana"'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + self.assertEqual(p.stdin, None) + + def test_stdout_none(self): + # .stdout is None when not redirected + p = subprocess.Popen([sys.executable, "-c", + 'print " this bit of output is from a ' + 'test of stdout in a different ' + 'process ..."'], + stdin=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + self.assertEqual(p.stdout, None) + + def test_stderr_none(self): + # .stderr is None when not redirected + p = subprocess.Popen([sys.executable, "-c", 'print "banana"'], + 
stdin=subprocess.PIPE, stdout=subprocess.PIPE) + p.wait() + self.assertEqual(p.stderr, None) + + def test_executable(self): + p = subprocess.Popen(["somethingyoudonthave", + "-c", "import sys; sys.exit(47)"], + executable=sys.executable) + p.wait() + self.assertEqual(p.returncode, 47) + + def test_stdin_pipe(self): + # stdin redirection + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.exit(sys.stdin.read() == "pear")'], + stdin=subprocess.PIPE) + p.stdin.write("pear") + p.stdin.close() + p.wait() + self.assertEqual(p.returncode, 1) + + def test_stdin_filedes(self): + # stdin is set to open file descriptor + tf = tempfile.TemporaryFile() + d = tf.fileno() + os.write(d, "pear") + os.lseek(d, 0, 0) + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.exit(sys.stdin.read() == "pear")'], + stdin=d) + p.wait() + self.assertEqual(p.returncode, 1) + + def test_stdin_fileobj(self): + # stdin is set to open file object + tf = tempfile.TemporaryFile() + tf.write("pear") + tf.seek(0) + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.exit(sys.stdin.read() == "pear")'], + stdin=tf) + p.wait() + self.assertEqual(p.returncode, 1) + + def test_stdout_pipe(self): + # stdout redirection + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.stdout.write("orange")'], + stdout=subprocess.PIPE) + self.assertEqual(p.stdout.read(), "orange") + + def test_stdout_filedes(self): + # stdout is set to open file descriptor + tf = tempfile.TemporaryFile() + d = tf.fileno() + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.stdout.write("orange")'], + stdout=d) + p.wait() + os.lseek(d, 0, 0) + self.assertEqual(os.read(d, 1024), "orange") + + def test_stdout_fileobj(self): + # stdout is set to open file object + tf = tempfile.TemporaryFile() + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.stdout.write("orange")'], + stdout=tf) + p.wait() + tf.seek(0) + self.assertEqual(tf.read(), "orange") + + def 
test_stderr_pipe(self): + # stderr redirection + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.stderr.write("strawberry")'], + stderr=subprocess.PIPE) + self.assertEqual(remove_stderr_debug_decorations(p.stderr.read()), + "strawberry") + + def test_stderr_filedes(self): + # stderr is set to open file descriptor + tf = tempfile.TemporaryFile() + d = tf.fileno() + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.stderr.write("strawberry")'], + stderr=d) + p.wait() + os.lseek(d, 0, 0) + self.assertEqual(remove_stderr_debug_decorations(os.read(d, 1024)), + "strawberry") + + def test_stderr_fileobj(self): + # stderr is set to open file object + tf = tempfile.TemporaryFile() + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.stderr.write("strawberry")'], + stderr=tf) + p.wait() + tf.seek(0) + self.assertEqual(remove_stderr_debug_decorations(tf.read()), + "strawberry") + + def test_stdout_stderr_pipe(self): + # capture stdout and stderr to the same pipe + p = subprocess.Popen([sys.executable, "-c", + 'import sys;' \ + 'sys.stdout.write("apple");' \ + 'sys.stdout.flush();' \ + 'sys.stderr.write("orange")'], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + output = p.stdout.read() + stripped = remove_stderr_debug_decorations(output) + self.assertEqual(stripped, "appleorange") + + def test_stdout_stderr_file(self): + # capture stdout and stderr to the same open file + tf = tempfile.TemporaryFile() + p = subprocess.Popen([sys.executable, "-c", + 'import sys;' \ + 'sys.stdout.write("apple");' \ + 'sys.stdout.flush();' \ + 'sys.stderr.write("orange")'], + stdout=tf, + stderr=tf) + p.wait() + tf.seek(0) + output = tf.read() + stripped = remove_stderr_debug_decorations(output) + self.assertEqual(stripped, "appleorange") + + def test_stdout_filedes_of_stdout(self): + # stdout is set to 1 (#1531862). 
+ cmd = r"import sys, os; sys.exit(os.write(sys.stdout.fileno(), '.\n'))" + rc = subprocess.call([sys.executable, "-c", cmd], stdout=1) + self.assertEquals(rc, 2) + + def test_cwd(self): + tmpdir = tempfile.gettempdir() + # We cannot use os.path.realpath to canonicalize the path, + # since it doesn't expand Tru64 {memb} strings. See bug 1063571. + cwd = os.getcwd() + os.chdir(tmpdir) + tmpdir = os.getcwd() + os.chdir(cwd) + p = subprocess.Popen([sys.executable, "-c", + 'import sys,os;' \ + 'sys.stdout.write(os.getcwd())'], + stdout=subprocess.PIPE, + cwd=tmpdir) + normcase = os.path.normcase + self.assertEqual(normcase(p.stdout.read()), normcase(tmpdir)) + + def test_env(self): + newenv = os.environ.copy() + newenv["FRUIT"] = "orange" + p = subprocess.Popen([sys.executable, "-c", + 'import sys,os;' \ + 'sys.stdout.write(os.getenv("FRUIT"))'], + stdout=subprocess.PIPE, + env=newenv) + self.assertEqual(p.stdout.read(), "orange") + + def test_communicate_stdin(self): + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.exit(sys.stdin.read() == "pear")'], + stdin=subprocess.PIPE) + p.communicate("pear") + self.assertEqual(p.returncode, 1) + + def test_communicate_stdout(self): + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.stdout.write("pineapple")'], + stdout=subprocess.PIPE) + (stdout, stderr) = p.communicate() + self.assertEqual(stdout, "pineapple") + self.assertEqual(stderr, None) + + def test_communicate_stderr(self): + p = subprocess.Popen([sys.executable, "-c", + 'import sys; sys.stderr.write("pineapple")'], + stderr=subprocess.PIPE) + (stdout, stderr) = p.communicate() + self.assertEqual(stdout, None) + # When running with a pydebug build, the # of references is outputted + # to stderr, so just check if stderr at least started with "pinapple" + self.assert_(stderr.startswith("pineapple")) + + def test_communicate(self): + p = subprocess.Popen([sys.executable, "-c", + 'import sys,os;' \ + 'sys.stderr.write("pineapple");' \ + 
'sys.stdout.write(sys.stdin.read())'], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + (stdout, stderr) = p.communicate("banana") + self.assertEqual(stdout, "banana") + self.assertEqual(remove_stderr_debug_decorations(stderr), + "pineapple") + + def test_communicate_returns(self): + # communicate() should return None if no redirection is active + p = subprocess.Popen([sys.executable, "-c", + "import sys; sys.exit(47)"]) + (stdout, stderr) = p.communicate() + self.assertEqual(stdout, None) + self.assertEqual(stderr, None) + + def test_communicate_pipe_buf(self): + # communicate() with writes larger than pipe_buf + # This test will probably deadlock rather than fail, if + # communicate() does not work properly. + x, y = os.pipe() + if mswindows: + pipe_buf = 512 + else: + pipe_buf = os.fpathconf(x, "PC_PIPE_BUF") + os.close(x) + os.close(y) + p = subprocess.Popen([sys.executable, "-c", + 'import sys,os;' + 'sys.stdout.write(sys.stdin.read(47));' \ + 'sys.stderr.write("xyz"*%d);' \ + 'sys.stdout.write(sys.stdin.read())' % pipe_buf], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + string_to_write = "abc"*pipe_buf + (stdout, stderr) = p.communicate(string_to_write) + self.assertEqual(stdout, string_to_write) + + def test_writes_before_communicate(self): + # stdin.write before communicate() + p = subprocess.Popen([sys.executable, "-c", + 'import sys,os;' \ + 'sys.stdout.write(sys.stdin.read())'], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + p.stdin.write("banana") + (stdout, stderr) = p.communicate("split") + self.assertEqual(stdout, "bananasplit") + self.assertEqual(remove_stderr_debug_decorations(stderr), "") + + def test_universal_newlines(self): + p = subprocess.Popen([sys.executable, "-c", + 'import sys,os;' + SETBINARY + + 'sys.stdout.write("line1\\n");' + 'sys.stdout.flush();' + ... [truncated message content] |
From: <pj...@us...> - 2008-02-18 19:22:50
|
Revision: 4150 http://jython.svn.sourceforge.net/jython/?rev=4150&view=rev Author: pjenvey Date: 2008-02-18 11:22:48 -0800 (Mon, 18 Feb 2008) Log Message: ----------- jython support for subprocess. PIPEs expose the Java Process' streams, otherwise we sync to the Process' streams with a thread, like popen2. stdin=None (share stdin with the parent) could block that coupling thread forever (and stopping it with Thread.interrupt() would close the parent's stdin), so we don't support syncing it Modified Paths: -------------- trunk/jython/Lib/subprocess.py trunk/jython/Lib/test/test_subprocess.py Modified: trunk/jython/Lib/subprocess.py =================================================================== --- trunk/jython/Lib/subprocess.py 2008-02-18 19:01:57 UTC (rev 4149) +++ trunk/jython/Lib/subprocess.py 2008-02-18 19:22:48 UTC (rev 4150) @@ -352,11 +352,11 @@ import sys mswindows = (sys.platform == "win32") +jython = sys.platform.startswith("java") import os import types import traceback -import gc # Exception classes used by this module. 
class CalledProcessError(Exception): @@ -395,10 +395,25 @@ wShowWindow = 0 class pywintypes: error = IOError +elif jython: + import errno + import javashell + import threading + import java.io.File + import java.io.FileDescriptor + import java.io.FileOutputStream + import java.io.IOException + import java.lang.IllegalThreadStateException + import java.lang.ProcessBuilder + import java.lang.Thread + import java.nio.ByteBuffer + import org.python.core.io.RawIOBase + import org.python.core.io.StreamIO else: import select import errno import fcntl + import gc import pickle __all__ = ["Popen", "PIPE", "STDOUT", "call", "check_call", "CalledProcessError"] @@ -529,6 +544,63 @@ return ''.join(result) +if jython: + if javashell._getOsType() in ('nt', 'dos'): + # Escape the command line arguments on Windows + escape_args = lambda args: [list2cmdline([arg]) for arg in args] + else: + escape_args = lambda args: args + + + class CouplerThread(java.lang.Thread): + + """Couples a reader and writer RawIOBase. + + Streams data from the reader's read_func (a RawIOBase readinto + method) to the writer's write_func (a RawIOBase write method) in + a separate thread. Optionally calls close_func when finished + streaming or an exception occurs. + + This thread will fail safe when interrupted by Java's + Thread.interrupt. 
+ """ + + # analagous to PC_PIPE_BUF, which is typically 512 or 4096 + bufsize = 4096 + + def __init__(self, name, read_func, write_func, close_func=None): + self.read_func = read_func + self.write_func = write_func + self.close_func = close_func + self.setName('CouplerThread-%s (%s)' % (id(self), name)) + self.setDaemon(True) + + def run(self): + buf = java.nio.ByteBuffer.allocate(self.bufsize) + while True: + try: + count = self.read_func(buf) + if count < 1: + if self.close_func: + self.close_func() + break + buf.flip() + self.write_func(buf) + buf.flip() + except IOError, ioe: + if self.close_func: + try: + self.close_func() + except: + pass + # XXX: hack, should really be a + # ClosedByInterruptError(IOError) exception + if str(ioe) == \ + 'java.nio.channels.ClosedByInterruptException': + return + raise + + class Popen(object): def __init__(self, args, bufsize=0, executable=None, stdin=None, stdout=None, stderr=None, @@ -558,6 +630,10 @@ if creationflags != 0: raise ValueError("creationflags is only supported on Windows " "platforms") + if jython: + if preexec_fn is not None: + raise ValueError("preexec_fn is not supported on the Jython " + "platform") self.stdin = None self.stdout = None @@ -608,6 +684,62 @@ os.close(errread) errread = None + if jython: + self._stdin_thread = None + self._stdout_thread = None + self._stderr_thread = None + + # 'ct' is for CouplerThread + proc = self._process + ct2cwrite = org.python.core.io.StreamIO(proc.getOutputStream(), + True) + c2ctread = org.python.core.io.StreamIO(proc.getInputStream(), True) + cterrread = org.python.core.io.StreamIO(proc.getErrorStream(), + True) + + # Use the java.lang.Process streams for PIPE, otherwise + # direct the desired file to/from the java.lang.Process + # streams in a separate thread + if p2cwrite == PIPE: + p2cwrite = ct2cwrite + else: + if p2cread is None: + # Coupling stdin is not supported: there's no way to + # cleanly interrupt it if it blocks the + # CouplerThread forever (we can 
Thread.interrupt() + # its CouplerThread but that closes stdin's Channel) + pass + else: + self._stdin_thread = self._coupler_thread('stdin', + p2cread.readinto, + ct2cwrite.write, + ct2cwrite.close) + self._stdin_thread.start() + + if c2pread == PIPE: + c2pread = c2ctread + else: + if c2pwrite is None: + c2pwrite = org.python.core.io.StreamIO( + java.io.FileOutputStream(java.io.FileDescriptor.out), + False) + self._stdout_thread = self._coupler_thread('stdout', + c2ctread.readinto, + c2pwrite.write) + self._stdout_thread.start() + + if errread == PIPE: + errread = cterrread + elif not self._stderr_is_stdout(errwrite, c2pwrite): + if errwrite is None: + errwrite = org.python.core.io.StreamIO( + java.io.FileOutputStream(java.io.FileDescriptor.err), + False) + self._stderr_thread = self._coupler_thread('stderr', + cterrread.readinto, + errwrite.write) + self._stderr_thread.start() + if p2cwrite is not None: self.stdin = os.fdopen(p2cwrite, 'wb', bufsize) if c2pread is not None: @@ -667,6 +799,61 @@ return self._communicate(input) + if mswindows or jython: + # + # Windows and Jython shared methods + # + def _readerthread(self, fh, buffer): + buffer.append(fh.read()) + + + def _communicate(self, input): + stdout = None # Return + stderr = None # Return + + if self.stdout: + stdout = [] + stdout_thread = threading.Thread(target=self._readerthread, + args=(self.stdout, stdout)) + stdout_thread.setDaemon(True) + stdout_thread.start() + if self.stderr: + stderr = [] + stderr_thread = threading.Thread(target=self._readerthread, + args=(self.stderr, stderr)) + stderr_thread.setDaemon(True) + stderr_thread.start() + + if self.stdin: + if input is not None: + self.stdin.write(input) + self.stdin.close() + + if self.stdout: + stdout_thread.join() + if self.stderr: + stderr_thread.join() + + # All data exchanged. Translate lists into strings. + if stdout is not None: + stdout = stdout[0] + if stderr is not None: + stderr = stderr[0] + + # Translate newlines, if requested. 
We cannot let the file + # object do the translation: It is based on stdio, which is + # impossible to combine with select (unless forcing no + # buffering). + if self.universal_newlines and hasattr(file, 'newlines'): + if stdout: + stdout = self._translate_newlines(stdout) + if stderr: + stderr = self._translate_newlines(stderr) + + self.wait() + return (stdout, stderr) + + if mswindows: # # Windows methods @@ -855,57 +1042,145 @@ self.returncode = GetExitCodeProcess(self._handle) return self.returncode + elif jython: + # + # Jython methods + # + def _get_handles(self, stdin, stdout, stderr): + """Construct and return tuple with IO objects: + p2cread, p2cwrite, c2pread, c2pwrite, errread, errwrite + """ + p2cread, p2cwrite = None, None + c2pread, c2pwrite = None, None + errread, errwrite = None, None - def _readerthread(self, fh, buffer): - buffer.append(fh.read()) + if stdin is None: + pass + elif stdin == PIPE: + p2cwrite = PIPE + elif isinstance(stdin, org.python.core.io.RawIOBase): + p2cread = stdin + else: + # Assuming file-like object + p2cread = stdin.fileno() + if stdout is None: + pass + elif stdout == PIPE: + c2pread = PIPE + elif isinstance(stdout, org.python.core.io.RawIOBase): + c2pwrite = stdout + else: + # Assuming file-like object + c2pwrite = stdout.fileno() - def _communicate(self, input): - stdout = None # Return - stderr = None # Return + if stderr is None: + pass + elif stderr == PIPE: + errread = PIPE + elif stderr == STDOUT or \ + isinstance(stderr, org.python.core.io.RawIOBase): + errwrite = stderr + else: + # Assuming file-like object + errwrite = stderr.fileno() + + return (p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite) - if self.stdout: - stdout = [] - stdout_thread = threading.Thread(target=self._readerthread, - args=(self.stdout, stdout)) - stdout_thread.setDaemon(True) - stdout_thread.start() - if self.stderr: - stderr = [] - stderr_thread = threading.Thread(target=self._readerthread, - args=(self.stderr, stderr)) - 
stderr_thread.setDaemon(True) - stderr_thread.start() - if self.stdin: - if input is not None: - self.stdin.write(input) - self.stdin.close() + def _stderr_is_stdout(self, errwrite, c2pwrite): + """Determine if the subprocess' stderr should be redirected to + stdout + """ + return errwrite == STDOUT or c2pwrite not in (None, PIPE) and \ + c2pwrite is errwrite - if self.stdout: - stdout_thread.join() - if self.stderr: - stderr_thread.join() - # All data exchanged. Translate lists into strings. - if stdout is not None: - stdout = stdout[0] - if stderr is not None: - stderr = stderr[0] + def _coupler_thread(self, *args, **kwargs): + """Return a CouplerThread""" + return CouplerThread(*args, **kwargs) - # Translate newlines, if requested. We cannot let the file - # object do the translation: It is based on stdio, which is - # impossible to combine with select (unless forcing no - # buffering). - if self.universal_newlines and hasattr(file, 'newlines'): - if stdout: - stdout = self._translate_newlines(stdout) - if stderr: - stderr = self._translate_newlines(stderr) - self.wait() - return (stdout, stderr) + def _execute_child(self, args, executable, preexec_fn, close_fds, + cwd, env, universal_newlines, + startupinfo, creationflags, shell, + p2cread, p2cwrite, + c2pread, c2pwrite, + errread, errwrite): + """Execute program (Java version)""" + if isinstance(args, types.StringTypes): + args = [args] + else: + args = list(args) + args = escape_args(args) + + if shell: + args = javashell._shellEnv.cmd + args + + if executable is not None: + args[0] = executable + + builder = java.lang.ProcessBuilder(args) + if env is not None: + builder_env = builder.environment() + builder_env.clear() + builder_env.putAll(dict(env)) + + if cwd is None: + cwd = os.getcwd() + elif not os.path.exists(cwd): + raise OSError(errno.ENOENT, errno.strerror(errno.ENOENT), cwd) + elif not os.path.isdir(cwd): + raise OSError(errno.ENOTDIR, errno.strerror(errno.ENOENT), cwd) + 
builder.directory(java.io.File(cwd)) + + # Let Java manage redirection of stderr to stdout (it's more + # accurate at doing so than CouplerThreads). We redirect not + # only when stderr is marked as STDOUT, but also when + # c2pwrite is errwrite + if self._stderr_is_stdout(errwrite, c2pwrite): + builder.redirectErrorStream(True) + + try: + self._process = builder.start() + except java.io.IOException: + executable = os.path.join(cwd, args[0]) + if not os.path.exists(executable): + raise OSError(errno.ENOENT, errno.strerror(errno.ENOENT), + args[0]) + raise OSError(errno.EACCES, errno.strerror(errno.EACCES), + args[0]) + self._child_created = True + + + def poll(self, _deadstate=None): + """Check if child process has terminated. Returns returncode + attribute.""" + if self.returncode is None: + try: + return self._process.exitValue() + except java.lang.IllegalThreadStateException: + pass + return self.returncode + + + def wait(self): + """Wait for child process to terminate. Returns returncode + attribute.""" + if self.returncode is None: + self.returncode = self._process.waitFor() + for coupler in (self._stdout_thread, self._stderr_thread): + if coupler: + coupler.join() + if self._stdin_thread: + # The stdin thread may be blocked forever, forcibly + # stop it + self._stdin_thread.interrupt() + return self.returncode + else: # # POSIX methods @@ -1243,8 +1518,39 @@ p.wait() +def _demo_jython(): + # + # Example 1: Return the number of processors on this machine + # + print "Running a jython subprocess to return the number of processors..." + p = Popen([sys.executable, "-c", + 'import sys;' \ + 'from java.lang import Runtime;' \ + 'sys.exit(Runtime.getRuntime().availableProcessors())']) + print p.wait() + + # + # Example 2: Connecting several subprocesses + # + print "Connecting two jython subprocesses..." 
+ p1 = Popen([sys.executable, "-c", + 'import os;' \ + 'print os.environ["foo"]'], env=dict(foo='bar'), + stdout=PIPE) + p2 = Popen([sys.executable, "-c", + 'import os, sys;' \ + 'their_foo = sys.stdin.read().strip();' \ + 'my_foo = os.environ["foo"];' \ + 'msg = "Their env\'s foo: %r, My env\'s foo: %r";' \ + 'print msg % (their_foo, my_foo)'], + env=dict(foo='baz'), stdin=p1.stdout, stdout=PIPE) + print p2.communicate()[0] + + if __name__ == "__main__": if mswindows: _demo_windows() + elif jython: + _demo_jython() else: _demo_posix() Modified: trunk/jython/Lib/test/test_subprocess.py =================================================================== --- trunk/jython/Lib/test/test_subprocess.py 2008-02-18 19:01:57 UTC (rev 4149) +++ trunk/jython/Lib/test/test_subprocess.py 2008-02-18 19:22:48 UTC (rev 4150) @@ -2,13 +2,13 @@ from test import test_support import subprocess import sys -import signal import os import tempfile import time import re mswindows = (sys.platform == "win32") +jython = sys.platform.startswith("java") # # Depends on the following external programs: Python @@ -17,9 +17,15 @@ if mswindows: SETBINARY = ('import msvcrt; msvcrt.setmode(sys.stdout.fileno(), ' 'os.O_BINARY);') +elif jython: + SETBINARY = ('import os,sys;' + 'sys.stdout = os.fdopen(sys.stdout.fileno(), "wb");') else: SETBINARY = '' +if not jython: + import signal + # In a debug build, stuff like "[6580 refs]" is printed to stderr at # shutdown time. That frustrates tests trying to check stderr produced # from a spawned Python process. @@ -237,7 +243,8 @@ def test_stdout_filedes_of_stdout(self): # stdout is set to 1 (#1531862). 
cmd = r"import sys, os; sys.exit(os.write(sys.stdout.fileno(), '.\n'))" - rc = subprocess.call([sys.executable, "-c", cmd], stdout=1) + rc = subprocess.call([sys.executable, "-c", cmd], + stdout=sys.__stdout__.fileno()) self.assertEquals(rc, 2) def test_cwd(self): @@ -316,13 +323,13 @@ # communicate() with writes larger than pipe_buf # This test will probably deadlock rather than fail, if # communicate() does not work properly. - x, y = os.pipe() - if mswindows: + if mswindows or jython: pipe_buf = 512 else: + x, y = os.pipe() pipe_buf = os.fpathconf(x, "PC_PIPE_BUF") - os.close(x) - os.close(y) + os.close(x) + os.close(y) p = subprocess.Popen([sys.executable, "-c", 'import sys,os;' 'sys.stdout.write(sys.stdin.read(47));' \ @@ -403,8 +410,13 @@ def test_no_leaking(self): # Make sure we leak no resources if not hasattr(test_support, "is_resource_enabled") \ - or test_support.is_resource_enabled("subprocess") and not mswindows: - max_handles = 1026 # too much for most UNIX systems + or test_support.is_resource_enabled("subprocess") and \ + not mswindows: + # 1026 is too much for most UNIX systems + max_handles = jython and 65 or 1026 + elif jython: + # Spawning java processes takes a long time + return else: max_handles = 65 for i in range(max_handles): @@ -475,7 +487,7 @@ # # POSIX tests # - if not mswindows: + if not mswindows and not jython: def test_exceptions(self): # catched & re-raised exceptions try: @@ -656,6 +668,39 @@ self.assertEqual(rc, 47) + # + # Jython tests + # + if jython: + def test_cwd_exception(self): + # catched & re-raised exceptions + self.assertRaises(OSError, subprocess.call, + [sys.executable, "-c", ""], + cwd="/this/path/does/not/exist") + + def test_path_exception(self): + tf = tempfile.NamedTemporaryFile() + self.assertRaises(OSError, subprocess.call, + tf.name) + self.assertRaises(OSError, subprocess.call, + "/this/path/does/not/exist") + + def test_invalid_args(self): + # invalid arguments should raise ValueError + 
self.assertRaises(ValueError, subprocess.call, + [sys.executable, + "-c", "import sys; sys.exit(47)"], + startupinfo=47) + self.assertRaises(ValueError, subprocess.call, + [sys.executable, + "-c", "import sys; sys.exit(47)"], + creationflags=47) + self.assertRaises(ValueError, subprocess.call, + [sys.executable, + "-c", "import sys; sys.exit(47)"], + preexec_fn=lambda: 1) + + def test_main(): test_support.run_unittest(ProcessTestCase) if hasattr(test_support, "reap_children"): This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <pj...@us...> - 2008-02-24 02:38:23
|
Revision: 4170 http://jython.svn.sourceforge.net/jython/?rev=4170&view=rev Author: pjenvey Date: 2008-02-23 18:38:21 -0800 (Sat, 23 Feb 2008) Log Message: ----------- o fix samefile and _resolve_link to work on Jython via java.io.File o disable sameopenfile, samestat and ismount due to the lack of fstat and st_ino/st_dev Modified Paths: -------------- trunk/jython/Lib/posixpath.py trunk/jython/Lib/test/test_posixpath.py Modified: trunk/jython/Lib/posixpath.py =================================================================== --- trunk/jython/Lib/posixpath.py 2008-02-24 02:11:03 UTC (rev 4169) +++ trunk/jython/Lib/posixpath.py 2008-02-24 02:38:21 UTC (rev 4170) @@ -10,14 +10,19 @@ for manipulation of the pathname component of URLs. """ +import java.io.File +import java.io.IOException +# XXX: os (org.python.modules.os) is broken when we're imported: look at +# javaos.name instead +import javaos import os import stat __all__ = ["normcase","isabs","join","splitdrive","split","splitext", "basename","dirname","commonprefix","getsize","getmtime", "getatime","getctime","islink","exists","lexists","isdir","isfile", - "ismount","walk","expanduser","expandvars","normpath","abspath", - "samefile","sameopenfile","samestat", + "walk","expanduser","expandvars","normpath","abspath", + "samefile", "curdir","pardir","sep","pathsep","defpath","altsep","extsep", "devnull","realpath","supports_unicode_filenames"] @@ -213,53 +218,64 @@ # Are two filenames really pointing to the same file? 
-def samefile(f1, f2): - """Test whether two pathnames reference the same actual file""" - s1 = os.stat(f1) - s2 = os.stat(f2) - return samestat(s1, s2) +if javaos.name == 'java': + def samefile(f1, f2): + """Test whether two pathnames reference the same actual file""" + canon1 = java.io.File(_ensure_str(f1)).getCanonicalPath() + canon2 = java.io.File(_ensure_str(f2)).getCanonicalPath() + return canon1 == canon2 +else: + def samefile(f1, f2): + """Test whether two pathnames reference the same actual file""" + s1 = os.stat(f1) + s2 = os.stat(f2) + return samestat(s1, s2) -# Are two open files really referencing the same file? -# (Not necessarily the same file descriptor!) +# XXX: Plain Jython lacks fstat and st_ino/st_dev +if javaos.name != 'java': + # Are two open files really referencing the same file? + # (Not necessarily the same file descriptor!) -def sameopenfile(fp1, fp2): - """Test whether two open file objects reference the same file""" - s1 = os.fstat(fp1) - s2 = os.fstat(fp2) - return samestat(s1, s2) + def sameopenfile(fp1, fp2): + """Test whether two open file objects reference the same file""" + s1 = os.fstat(fp1) + s2 = os.fstat(fp2) + return samestat(s1, s2) -# Are two stat buffers (obtained from stat, fstat or lstat) -# describing the same file? + # Are two stat buffers (obtained from stat, fstat or lstat) + # describing the same file? -def samestat(s1, s2): - """Test whether two stat buffers reference the same file""" - return s1.st_ino == s2.st_ino and \ - s1.st_dev == s2.st_dev + def samestat(s1, s2): + """Test whether two stat buffers reference the same file""" + return s1.st_ino == s2.st_ino and \ + s1.st_dev == s2.st_dev -# Is a path a mount point? -# (Does this work for all UNIXes? Is it even guaranteed to work by Posix?) + # Is a path a mount point? + # (Does this work for all UNIXes? Is it even guaranteed to work by Posix?) 
-def ismount(path): - """Test whether a path is a mount point""" - try: - s1 = os.lstat(path) - s2 = os.lstat(join(path, '..')) - except os.error: - return False # It doesn't exist -- so not a mount point :-) - dev1 = s1.st_dev - dev2 = s2.st_dev - if dev1 != dev2: - return True # path/.. on a different device as path - ino1 = s1.st_ino - ino2 = s2.st_ino - if ino1 == ino2: - return True # path/.. is the same i-node as path - return False + def ismount(path): + """Test whether a path is a mount point""" + try: + s1 = os.lstat(path) + s2 = os.lstat(join(path, '..')) + except os.error: + return False # It doesn't exist -- so not a mount point :-) + dev1 = s1.st_dev + dev2 = s2.st_dev + if dev1 != dev2: + return True # path/.. on a different device as path + ino1 = s1.st_ino + ino2 = s2.st_ino + if ino1 == ino2: + return True # path/.. is the same i-node as path + return False + __all__.extend(["sameopenfile", "samestat", "ismount"]) + # Directory tree walk. # For each directory under top (including top itself, but excluding # '.' and '..'), func(arg, dirname, filenames) is called, where @@ -317,17 +333,12 @@ i = len(path) if i == 1: if 'HOME' not in os.environ: - import pwd - userhome = pwd.getpwuid(os.getuid()).pw_dir + return path else: userhome = os.environ['HOME'] else: - import pwd - try: - pwent = pwd.getpwnam(path[1:i]) - except KeyError: - return path - userhome = pwent.pw_dir + # XXX: Jython lacks the pwd module: '~user' isn't supported + return path userhome = userhome.rstrip('/') return userhome + path[i:] @@ -430,24 +441,55 @@ return abspath(filename) -def _resolve_link(path): - """Internal helper function. Takes a path and follows symlinks - until we either arrive at something that isn't a symlink, or - encounter a path we've seen before (meaning that there's a loop). 
- """ - paths_seen = [] - while islink(path): - if path in paths_seen: - # Already seen this path, so we must have a symlink loop +if javaos.name == 'java': + def _resolve_link(path): + """Internal helper function. Takes a path and follows symlinks + until we either arrive at something that isn't a symlink, or + encounter a path we've seen before (meaning that there's a loop). + """ + try: + return str(java.io.File(path).getCanonicalPath()) + except java.io.IOException: return None - paths_seen.append(path) - # Resolve where the link points to - resolved = os.readlink(path) - if not isabs(resolved): - dir = dirname(path) - path = normpath(join(dir, resolved)) - else: - path = normpath(resolved) - return path +else: + def _resolve_link(path): + """Internal helper function. Takes a path and follows symlinks + until we either arrive at something that isn't a symlink, or + encounter a path we've seen before (meaning that there's a loop). + """ + paths_seen = [] + while islink(path): + if path in paths_seen: + # Already seen this path, so we must have a symlink loop + return None + paths_seen.append(path) + # Resolve where the link points to + resolved = os.readlink(path) + if not isabs(resolved): + dir = dirname(path) + path = normpath(join(dir, resolved)) + else: + path = normpath(resolved) + return path + +def _ensure_str(obj): + """Ensure obj is a string, otherwise raise a TypeError""" + if isinstance(obj, basestring): + return obj + raise TypeError('coercing to Unicode: need string or buffer, %s found' % \ + _type_name(obj)) + + +def _type_name(obj): + """Determine the appropriate type name of obj for display""" + TPFLAGS_HEAPTYPE = 1 << 9 + type_name = '' + obj_type = type(obj) + is_heap = obj_type.__flags__ & TPFLAGS_HEAPTYPE == TPFLAGS_HEAPTYPE + if not is_heap and obj_type.__module__ != '__builtin__': + type_name = '%s.' 
% obj_type.__module__ + type_name += obj_type.__name__ + return type_name + supports_unicode_filenames = False Modified: trunk/jython/Lib/test/test_posixpath.py =================================================================== --- trunk/jython/Lib/test/test_posixpath.py 2008-02-24 02:11:03 UTC (rev 4169) +++ trunk/jython/Lib/test/test_posixpath.py 2008-02-24 02:38:21 UTC (rev 4170) @@ -284,59 +284,60 @@ self.assertRaises(TypeError, posixpath.samefile) - def test_samestat(self): - f = open(test_support.TESTFN + "1", "wb") - try: - f.write("foo") - f.close() - self.assertIs( - posixpath.samestat( - os.stat(test_support.TESTFN + "1"), - os.stat(test_support.TESTFN + "1") - ), - True - ) - # If we don't have links, assume that os.stat() doesn't return resonable - # inode information and thus, that samefile() doesn't work - if hasattr(os, "symlink"): + if os.name != 'java': + def test_samestat(self): + f = open(test_support.TESTFN + "1", "wb") + try: + f.write("foo") + f.close() + self.assertIs( + posixpath.samestat( + os.stat(test_support.TESTFN + "1"), + os.stat(test_support.TESTFN + "1") + ), + True + ) + # If we don't have links, assume that os.stat() doesn't return resonable + # inode information and thus, that samefile() doesn't work if hasattr(os, "symlink"): - os.symlink(test_support.TESTFN + "1", test_support.TESTFN + "2") + if hasattr(os, "symlink"): + os.symlink(test_support.TESTFN + "1", test_support.TESTFN + "2") + self.assertIs( + posixpath.samestat( + os.stat(test_support.TESTFN + "1"), + os.stat(test_support.TESTFN + "2") + ), + True + ) + os.remove(test_support.TESTFN + "2") + f = open(test_support.TESTFN + "2", "wb") + f.write("bar") + f.close() self.assertIs( posixpath.samestat( os.stat(test_support.TESTFN + "1"), os.stat(test_support.TESTFN + "2") ), - True + False ) + finally: + if not f.close(): + f.close() + try: + os.remove(test_support.TESTFN + "1") + except os.error: + pass + try: os.remove(test_support.TESTFN + "2") - f = 
open(test_support.TESTFN + "2", "wb") - f.write("bar") - f.close() - self.assertIs( - posixpath.samestat( - os.stat(test_support.TESTFN + "1"), - os.stat(test_support.TESTFN + "2") - ), - False - ) - finally: - if not f.close(): - f.close() - try: - os.remove(test_support.TESTFN + "1") - except os.error: - pass - try: - os.remove(test_support.TESTFN + "2") - except os.error: - pass + except os.error: + pass - self.assertRaises(TypeError, posixpath.samestat) + self.assertRaises(TypeError, posixpath.samestat) - def test_ismount(self): - self.assertIs(posixpath.ismount("/"), True) + def test_ismount(self): + self.assertIs(posixpath.ismount("/"), True) - self.assertRaises(TypeError, posixpath.ismount) + self.assertRaises(TypeError, posixpath.ismount) def test_expanduser(self): self.assertEqual(posixpath.expanduser("foo"), "foo") This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |
From: <cg...@us...> - 2008-02-28 16:55:41
|
Revision: 4185 http://jython.svn.sourceforge.net/jython/?rev=4185&view=rev Author: cgroves Date: 2008-02-28 08:55:33 -0800 (Thu, 28 Feb 2008) Log Message: ----------- Tabs to spaces Modified Paths: -------------- trunk/jython/Lib/dbexts.py trunk/jython/Lib/distutils/sysconfig.py trunk/jython/Lib/isql.py trunk/jython/Lib/javapath.py trunk/jython/Lib/marshal.py trunk/jython/Lib/pawt/__init__.py trunk/jython/Lib/pawt/swing.py trunk/jython/Lib/popen2.py trunk/jython/Lib/select.py trunk/jython/Lib/site.py trunk/jython/Lib/test/Graph.py trunk/jython/Lib/test/bugs/bugs100.py trunk/jython/Lib/test/bugs/bugs101.py trunk/jython/Lib/test/re_tests.py trunk/jython/Lib/test/regrtest.py trunk/jython/Lib/test/test_builtin_jy.py trunk/jython/Lib/test/test_cmath.py trunk/jython/Lib/test/test_enumerate.py trunk/jython/Lib/test/test_func_syntax_jy.py trunk/jython/Lib/test/test_java_integration.py trunk/jython/Lib/test/test_jbasic.py trunk/jython/Lib/test/test_jser.py trunk/jython/Lib/test/test_jsubclass.py trunk/jython/Lib/test/test_jy_compile.py trunk/jython/Lib/test/test_list_jy.py trunk/jython/Lib/test/test_mailbox.py trunk/jython/Lib/test/test_pow.py trunk/jython/Lib/test/test_re_jy.py trunk/jython/Lib/test/test_sax.py trunk/jython/Lib/test/test_sort.py trunk/jython/Lib/test/test_subclasses_jy.py trunk/jython/Lib/test/test_support.py trunk/jython/Lib/test/test_thread.py trunk/jython/Lib/test/test_thread_local.py trunk/jython/Lib/test/test_zlib.py trunk/jython/Lib/test/whrandom.py trunk/jython/Lib/test/zxjdbc/dbextstest.py trunk/jython/Lib/test/zxjdbc/jndi.py trunk/jython/Lib/test/zxjdbc/runner.py trunk/jython/Lib/test/zxjdbc/sptest.py trunk/jython/Lib/test/zxjdbc/zxtest.py trunk/jython/Lib/xml/dom/minidom.py trunk/jython/Lib/xml/dom/pulldom.py trunk/jython/Lib/xml/sax/drivers2/drv_javasax.py Modified: trunk/jython/Lib/dbexts.py =================================================================== --- trunk/jython/Lib/dbexts.py 2008-02-27 12:52:31 UTC (rev 4184) +++ 
trunk/jython/Lib/dbexts.py 2008-02-28 16:55:33 UTC (rev 4185) @@ -56,671 +56,671 @@ choose = lambda bool, a, b: (bool and [a] or [b])[0] def console(rows, headers=()): - """Format the results into a list of strings (one for each row): + """Format the results into a list of strings (one for each row): - <header> - <headersep> - <row1> - <row2> - ... + <header> + <headersep> + <row1> + <row2> + ... - headers may be given as list of strings. + headers may be given as list of strings. - Columns are separated by colsep; the header is separated from - the result set by a line of headersep characters. + Columns are separated by colsep; the header is separated from + the result set by a line of headersep characters. - The function calls stringify to format the value data into a string. - It defaults to calling str() and striping leading and trailing whitespace. + The function calls stringify to format the value data into a string. + It defaults to calling str() and striping leading and trailing whitespace. 
- - copied and modified from mxODBC - """ + - copied and modified from mxODBC + """ - # Check row entry lengths - output = [] - headers = map(lambda header: header.upper(), list(map(lambda x: x or "", headers))) - collen = map(len,headers) - output.append(headers) - if rows and len(rows) > 0: - for row in rows: - row = map(lambda x: str(x), row) - for i in range(len(row)): - entry = row[i] - if collen[i] < len(entry): - collen[i] = len(entry) - output.append(row) - if len(output) == 1: - affected = "0 rows affected" - elif len(output) == 2: - affected = "1 row affected" - else: - affected = "%d rows affected" % (len(output) - 1) + # Check row entry lengths + output = [] + headers = map(lambda header: header.upper(), list(map(lambda x: x or "", headers))) + collen = map(len,headers) + output.append(headers) + if rows and len(rows) > 0: + for row in rows: + row = map(lambda x: str(x), row) + for i in range(len(row)): + entry = row[i] + if collen[i] < len(entry): + collen[i] = len(entry) + output.append(row) + if len(output) == 1: + affected = "0 rows affected" + elif len(output) == 2: + affected = "1 row affected" + else: + affected = "%d rows affected" % (len(output) - 1) - # Format output - for i in range(len(output)): - row = output[i] - l = [] - for j in range(len(row)): - l.append('%-*s' % (collen[j],row[j])) - output[i] = " | ".join(l) + # Format output + for i in range(len(output)): + row = output[i] + l = [] + for j in range(len(row)): + l.append('%-*s' % (collen[j],row[j])) + output[i] = " | ".join(l) - # Insert header separator - totallen = len(output[0]) - output[1:1] = ["-"*(totallen/len("-"))] - output.append("\n" + affected) - return output + # Insert header separator + totallen = len(output[0]) + output[1:1] = ["-"*(totallen/len("-"))] + output.append("\n" + affected) + return output def html(rows, headers=()): - output = [] - output.append('<table class="results">') - output.append('<tr class="headers">') - headers = map(lambda x: '<td 
class="header">%s</td>' % (x.upper()), list(headers)) - map(output.append, headers) - output.append('</tr>') - if rows and len(rows) > 0: - for row in rows: - output.append('<tr class="row">') - row = map(lambda x: '<td class="value">%s</td>' % (x), row) - map(output.append, row) - output.append('</tr>') - output.append('</table>') - return output + output = [] + output.append('<table class="results">') + output.append('<tr class="headers">') + headers = map(lambda x: '<td class="header">%s</td>' % (x.upper()), list(headers)) + map(output.append, headers) + output.append('</tr>') + if rows and len(rows) > 0: + for row in rows: + output.append('<tr class="row">') + row = map(lambda x: '<td class="value">%s</td>' % (x), row) + map(output.append, row) + output.append('</tr>') + output.append('</table>') + return output comments = lambda x: re.compile("{.*?}", re.S).sub("", x, 0) class mxODBCProxy: - """Wraps mxODBC to provide proxy support for zxJDBC's additional parameters.""" - def __init__(self, c): - self.c = c - def __getattr__(self, name): - if name == "execute": - return self.execute - elif name == "gettypeinfo": - return self.gettypeinfo - else: - return getattr(self.c, name) - def execute(self, sql, params=None, bindings=None, maxrows=None): - if params: - self.c.execute(sql, params) - else: - self.c.execute(sql) - def gettypeinfo(self, typeid=None): - if typeid: - self.c.gettypeinfo(typeid) + """Wraps mxODBC to provide proxy support for zxJDBC's additional parameters.""" + def __init__(self, c): + self.c = c + def __getattr__(self, name): + if name == "execute": + return self.execute + elif name == "gettypeinfo": + return self.gettypeinfo + else: + return getattr(self.c, name) + def execute(self, sql, params=None, bindings=None, maxrows=None): + if params: + self.c.execute(sql, params) + else: + self.c.execute(sql) + def gettypeinfo(self, typeid=None): + if typeid: + self.c.gettypeinfo(typeid) class executor: - """Handles the insertion of values given 
dynamic data.""" - def __init__(self, table, cols): - self.cols = cols - self.table = table - if self.cols: - self.sql = "insert into %s (%s) values (%s)" % (table, ",".join(self.cols), ",".join(("?",) * len(self.cols))) - else: - self.sql = "insert into %s values (%%s)" % (table) - def execute(self, db, rows, bindings): - assert rows and len(rows) > 0, "must have at least one row" - if self.cols: - sql = self.sql - else: - sql = self.sql % (",".join(("?",) * len(rows[0]))) - db.raw(sql, rows, bindings) + """Handles the insertion of values given dynamic data.""" + def __init__(self, table, cols): + self.cols = cols + self.table = table + if self.cols: + self.sql = "insert into %s (%s) values (%s)" % (table, ",".join(self.cols), ",".join(("?",) * len(self.cols))) + else: + self.sql = "insert into %s values (%%s)" % (table) + def execute(self, db, rows, bindings): + assert rows and len(rows) > 0, "must have at least one row" + if self.cols: + sql = self.sql + else: + sql = self.sql % (",".join(("?",) * len(rows[0]))) + db.raw(sql, rows, bindings) def connect(dbname): - return dbexts(dbname) + return dbexts(dbname) def lookup(dbname): - return dbexts(jndiname=dbname) + return dbexts(jndiname=dbname) class dbexts: - def __init__(self, dbname=None, cfg=None, formatter=console, autocommit=0, jndiname=None, out=None): - self.verbose = 1 - self.results = [] - self.headers = [] - self.autocommit = autocommit - self.formatter = formatter - self.out = out - self.lastrowid = None - self.updatecount = None + def __init__(self, dbname=None, cfg=None, formatter=console, autocommit=0, jndiname=None, out=None): + self.verbose = 1 + self.results = [] + self.headers = [] + self.autocommit = autocommit + self.formatter = formatter + self.out = out + self.lastrowid = None + self.updatecount = None - if not jndiname: - if cfg == None: - fn = os.path.join(os.path.split(__file__)[0], "dbexts.ini") - if not os.path.exists(fn): - fn = os.path.join(os.environ['HOME'], ".dbexts") - self.dbs = 
IniParser(fn) - elif isinstance(cfg, IniParser): - self.dbs = cfg - else: - self.dbs = IniParser(cfg) - if dbname == None: dbname = self.dbs[("default", "name")] + if not jndiname: + if cfg == None: + fn = os.path.join(os.path.split(__file__)[0], "dbexts.ini") + if not os.path.exists(fn): + fn = os.path.join(os.environ['HOME'], ".dbexts") + self.dbs = IniParser(fn) + elif isinstance(cfg, IniParser): + self.dbs = cfg + else: + self.dbs = IniParser(cfg) + if dbname == None: dbname = self.dbs[("default", "name")] - if __OS__ == 'java': + if __OS__ == 'java': - from com.ziclix.python.sql import zxJDBC - database = zxJDBC - if not jndiname: - t = self.dbs[("jdbc", dbname)] - self.dburl, dbuser, dbpwd, jdbcdriver = t['url'], t['user'], t['pwd'], t['driver'] - if t.has_key('datahandler'): - self.datahandler = [] - for dh in t['datahandler'].split(','): - classname = dh.split(".")[-1] - datahandlerclass = __import__(dh, globals(), locals(), classname) - self.datahandler.append(datahandlerclass) - keys = [x for x in t.keys() if x not in ['url', 'user', 'pwd', 'driver', 'datahandler', 'name']] - props = {} - for a in keys: - props[a] = t[a] - self.db = apply(database.connect, (self.dburl, dbuser, dbpwd, jdbcdriver), props) - else: - self.db = database.lookup(jndiname) - self.db.autocommit = self.autocommit + from com.ziclix.python.sql import zxJDBC + database = zxJDBC + if not jndiname: + t = self.dbs[("jdbc", dbname)] + self.dburl, dbuser, dbpwd, jdbcdriver = t['url'], t['user'], t['pwd'], t['driver'] + if t.has_key('datahandler'): + self.datahandler = [] + for dh in t['datahandler'].split(','): + classname = dh.split(".")[-1] + datahandlerclass = __import__(dh, globals(), locals(), classname) + self.datahandler.append(datahandlerclass) + keys = [x for x in t.keys() if x not in ['url', 'user', 'pwd', 'driver', 'datahandler', 'name']] + props = {} + for a in keys: + props[a] = t[a] + self.db = apply(database.connect, (self.dburl, dbuser, dbpwd, jdbcdriver), props) + else: + 
self.db = database.lookup(jndiname) + self.db.autocommit = self.autocommit - elif __OS__ == 'nt': + elif __OS__ == 'nt': - for modname in ["mx.ODBC.Windows", "ODBC.Windows"]: - try: - database = __import__(modname, globals(), locals(), "Windows") - break - except: - continue - else: - raise ImportError("unable to find appropriate mxODBC module") + for modname in ["mx.ODBC.Windows", "ODBC.Windows"]: + try: + database = __import__(modname, globals(), locals(), "Windows") + break + except: + continue + else: + raise ImportError("unable to find appropriate mxODBC module") - t = self.dbs[("odbc", dbname)] - self.dburl, dbuser, dbpwd = t['url'], t['user'], t['pwd'] - self.db = database.Connect(self.dburl, dbuser, dbpwd, clear_auto_commit=1) + t = self.dbs[("odbc", dbname)] + self.dburl, dbuser, dbpwd = t['url'], t['user'], t['pwd'] + self.db = database.Connect(self.dburl, dbuser, dbpwd, clear_auto_commit=1) - self.dbname = dbname - for a in database.sqltype.keys(): - setattr(self, database.sqltype[a], a) - for a in dir(database): - try: - p = getattr(database, a) - if issubclass(p, Exception): - setattr(self, a, p) - except: - continue - del database + self.dbname = dbname + for a in database.sqltype.keys(): + setattr(self, database.sqltype[a], a) + for a in dir(database): + try: + p = getattr(database, a) + if issubclass(p, Exception): + setattr(self, a, p) + except: + continue + del database - def __str__(self): - return self.dburl + def __str__(self): + return self.dburl - def __repr__(self): - return self.dburl + def __repr__(self): + return self.dburl - def __getattr__(self, name): - if "cfg" == name: - return self.dbs.cfg - raise AttributeError("'dbexts' object has no attribute '%s'" % (name)) + def __getattr__(self, name): + if "cfg" == name: + return self.dbs.cfg + raise AttributeError("'dbexts' object has no attribute '%s'" % (name)) - def close(self): - """ close the connection to the database """ - self.db.close() + def close(self): + """ close the connection 
to the database """ + self.db.close() - def begin(self, style=None): - """ reset ivars and return a new cursor, possibly binding an auxiliary datahandler """ - self.headers, self.results = [], [] - if style: - c = self.db.cursor(style) - else: - c = self.db.cursor() - if __OS__ == 'java': - if hasattr(self, 'datahandler'): - for dh in self.datahandler: - c.datahandler = dh(c.datahandler) - else: - c = mxODBCProxy(c) - return c + def begin(self, style=None): + """ reset ivars and return a new cursor, possibly binding an auxiliary datahandler """ + self.headers, self.results = [], [] + if style: + c = self.db.cursor(style) + else: + c = self.db.cursor() + if __OS__ == 'java': + if hasattr(self, 'datahandler'): + for dh in self.datahandler: + c.datahandler = dh(c.datahandler) + else: + c = mxODBCProxy(c) + return c - def commit(self, cursor=None, close=1): - """ commit the cursor and create the result set """ - if cursor and cursor.description: - self.headers = cursor.description - self.results = cursor.fetchall() - if hasattr(cursor, "nextset"): - s = cursor.nextset() - while s: - self.results += cursor.fetchall() - s = cursor.nextset() - if hasattr(cursor, "lastrowid"): - self.lastrowid = cursor.lastrowid - if hasattr(cursor, "updatecount"): - self.updatecount = cursor.updatecount - if not self.autocommit or cursor is None: - if not self.db.autocommit: - self.db.commit() - if cursor and close: cursor.close() + def commit(self, cursor=None, close=1): + """ commit the cursor and create the result set """ + if cursor and cursor.description: + self.headers = cursor.description + self.results = cursor.fetchall() + if hasattr(cursor, "nextset"): + s = cursor.nextset() + while s: + self.results += cursor.fetchall() + s = cursor.nextset() + if hasattr(cursor, "lastrowid"): + self.lastrowid = cursor.lastrowid + if hasattr(cursor, "updatecount"): + self.updatecount = cursor.updatecount + if not self.autocommit or cursor is None: + if not self.db.autocommit: + self.db.commit() 
+ if cursor and close: cursor.close() - def rollback(self): - """ rollback the cursor """ - self.db.rollback() + def rollback(self): + """ rollback the cursor """ + self.db.rollback() - def prepare(self, sql): - """ prepare the sql statement """ - cur = self.begin() - try: - return cur.prepare(sql) - finally: - self.commit(cur) + def prepare(self, sql): + """ prepare the sql statement """ + cur = self.begin() + try: + return cur.prepare(sql) + finally: + self.commit(cur) - def display(self): - """ using the formatter, display the results """ - if self.formatter and self.verbose > 0: - res = self.results - if res: - print >> self.out, "" - for a in self.formatter(res, map(lambda x: x[0], self.headers)): - print >> self.out, a - print >> self.out, "" + def display(self): + """ using the formatter, display the results """ + if self.formatter and self.verbose > 0: + res = self.results + if res: + print >> self.out, "" + for a in self.formatter(res, map(lambda x: x[0], self.headers)): + print >> self.out, a + print >> self.out, "" - def __execute__(self, sql, params=None, bindings=None, maxrows=None): - """ the primary execution method """ - cur = self.begin() - try: - if bindings: - cur.execute(sql, params, bindings, maxrows=maxrows) - elif params: - cur.execute(sql, params, maxrows=maxrows) - else: - cur.execute(sql, maxrows=maxrows) - finally: - self.commit(cur, close=isinstance(sql, StringType)) + def __execute__(self, sql, params=None, bindings=None, maxrows=None): + """ the primary execution method """ + cur = self.begin() + try: + if bindings: + cur.execute(sql, params, bindings, maxrows=maxrows) + elif params: + cur.execute(sql, params, maxrows=maxrows) + else: + cur.execute(sql, maxrows=maxrows) + finally: + self.commit(cur, close=isinstance(sql, StringType)) - def isql(self, sql, params=None, bindings=None, maxrows=None): - """ execute and display the sql """ - self.raw(sql, params, bindings, maxrows=maxrows) - self.display() + def isql(self, sql, params=None, 
bindings=None, maxrows=None): + """ execute and display the sql """ + self.raw(sql, params, bindings, maxrows=maxrows) + self.display() - def raw(self, sql, params=None, bindings=None, delim=None, comments=comments, maxrows=None): - """ execute the sql and return a tuple of (headers, results) """ - if delim: - headers = [] - results = [] - if type(sql) == type(StringType): - if comments: sql = comments(sql) - statements = filter(lambda x: len(x) > 0, - map(lambda statement: statement.strip(), sql.split(delim))) - else: - statements = [sql] - for a in statements: - self.__execute__(a, params, bindings, maxrows=maxrows) - headers.append(self.headers) - results.append(self.results) - self.headers = headers - self.results = results - else: - self.__execute__(sql, params, bindings, maxrows=maxrows) - return (self.headers, self.results) + def raw(self, sql, params=None, bindings=None, delim=None, comments=comments, maxrows=None): + """ execute the sql and return a tuple of (headers, results) """ + if delim: + headers = [] + results = [] + if type(sql) == type(StringType): + if comments: sql = comments(sql) + statements = filter(lambda x: len(x) > 0, + map(lambda statement: statement.strip(), sql.split(delim))) + else: + statements = [sql] + for a in statements: + self.__execute__(a, params, bindings, maxrows=maxrows) + headers.append(self.headers) + results.append(self.results) + self.headers = headers + self.results = results + else: + self.__execute__(sql, params, bindings, maxrows=maxrows) + return (self.headers, self.results) - def callproc(self, procname, params=None, bindings=None, maxrows=None): - """ execute a stored procedure """ - cur = self.begin() - try: - cur.callproc(procname, params=params, bindings=bindings, maxrows=maxrows) - finally: - self.commit(cur) - self.display() + def callproc(self, procname, params=None, bindings=None, maxrows=None): + """ execute a stored procedure """ + cur = self.begin() + try: + cur.callproc(procname, params=params, 
bindings=bindings, maxrows=maxrows) + finally: + self.commit(cur) + self.display() - def pk(self, table, owner=None, schema=None): - """ display the table's primary keys """ - cur = self.begin() - cur.primarykeys(schema, owner, table) - self.commit(cur) - self.display() + def pk(self, table, owner=None, schema=None): + """ display the table's primary keys """ + cur = self.begin() + cur.primarykeys(schema, owner, table) + self.commit(cur) + self.display() - def fk(self, primary_table=None, foreign_table=None, owner=None, schema=None): - """ display the table's foreign keys """ - cur = self.begin() - if primary_table and foreign_table: - cur.foreignkeys(schema, owner, primary_table, schema, owner, foreign_table) - elif primary_table: - cur.foreignkeys(schema, owner, primary_table, schema, owner, None) - elif foreign_table: - cur.foreignkeys(schema, owner, None, schema, owner, foreign_table) - self.commit(cur) - self.display() + def fk(self, primary_table=None, foreign_table=None, owner=None, schema=None): + """ display the table's foreign keys """ + cur = self.begin() + if primary_table and foreign_table: + cur.foreignkeys(schema, owner, primary_table, schema, owner, foreign_table) + elif primary_table: + cur.foreignkeys(schema, owner, primary_table, schema, owner, None) + elif foreign_table: + cur.foreignkeys(schema, owner, None, schema, owner, foreign_table) + self.commit(cur) + self.display() - def table(self, table=None, types=("TABLE",), owner=None, schema=None): - """If no table argument, displays a list of all tables. If a table argument, - displays the columns of the given table.""" - cur = self.begin() - if table: - cur.columns(schema, owner, table, None) - else: - cur.tables(schema, owner, None, types) - self.commit(cur) - self.display() + def table(self, table=None, types=("TABLE",), owner=None, schema=None): + """If no table argument, displays a list of all tables. 
If a table argument, + displays the columns of the given table.""" + cur = self.begin() + if table: + cur.columns(schema, owner, table, None) + else: + cur.tables(schema, owner, None, types) + self.commit(cur) + self.display() - def proc(self, proc=None, owner=None, schema=None): - """If no proc argument, displays a list of all procedures. If a proc argument, - displays the parameters of the given procedure.""" - cur = self.begin() - if proc: - cur.procedurecolumns(schema, owner, proc, None) - else: - cur.procedures(schema, owner, None) - self.commit(cur) - self.display() + def proc(self, proc=None, owner=None, schema=None): + """If no proc argument, displays a list of all procedures. If a proc argument, + displays the parameters of the given procedure.""" + cur = self.begin() + if proc: + cur.procedurecolumns(schema, owner, proc, None) + else: + cur.procedures(schema, owner, None) + self.commit(cur) + self.display() - def stat(self, table, qualifier=None, owner=None, unique=0, accuracy=0): - """ display the table's indicies """ - cur = self.begin() - cur.statistics(qualifier, owner, table, unique, accuracy) - self.commit(cur) - self.display() + def stat(self, table, qualifier=None, owner=None, unique=0, accuracy=0): + """ display the table's indicies """ + cur = self.begin() + cur.statistics(qualifier, owner, table, unique, accuracy) + self.commit(cur) + self.display() - def typeinfo(self, sqltype=None): - """ display the types available for the database """ - cur = self.begin() - cur.gettypeinfo(sqltype) - self.commit(cur) - self.display() + def typeinfo(self, sqltype=None): + """ display the types available for the database """ + cur = self.begin() + cur.gettypeinfo(sqltype) + self.commit(cur) + self.display() - def tabletypeinfo(self): - """ display the table types available for the database """ - cur = self.begin() - cur.gettabletypeinfo() - self.commit(cur) - self.display() + def tabletypeinfo(self): + """ display the table types available for the database 
""" + cur = self.begin() + cur.gettabletypeinfo() + self.commit(cur) + self.display() - def schema(self, table, full=0, sort=1, owner=None): - """Displays a Schema object for the table. If full is true, then generates - references to the table in addition to the standard fields. If sort is true, - sort all the items in the schema, else leave them in db dependent order.""" - print >> self.out, str(Schema(self, table, owner, full, sort)) + def schema(self, table, full=0, sort=1, owner=None): + """Displays a Schema object for the table. If full is true, then generates + references to the table in addition to the standard fields. If sort is true, + sort all the items in the schema, else leave them in db dependent order.""" + print >> self.out, str(Schema(self, table, owner, full, sort)) - def bulkcopy(self, dst, table, include=[], exclude=[], autobatch=0, executor=executor): - """Returns a Bulkcopy object using the given table.""" - if type(dst) == type(""): - dst = dbexts(dst, cfg=self.dbs) - bcp = Bulkcopy(dst, table, include=include, exclude=exclude, autobatch=autobatch, executor=executor) - return bcp + def bulkcopy(self, dst, table, include=[], exclude=[], autobatch=0, executor=executor): + """Returns a Bulkcopy object using the given table.""" + if type(dst) == type(""): + dst = dbexts(dst, cfg=self.dbs) + bcp = Bulkcopy(dst, table, include=include, exclude=exclude, autobatch=autobatch, executor=executor) + return bcp - def bcp(self, src, table, where='(1=1)', params=[], include=[], exclude=[], autobatch=0, executor=executor): - """Bulkcopy of rows from a src database to the current database for a given table and where clause.""" - if type(src) == type(""): - src = dbexts(src, cfg=self.dbs) - bcp = self.bulkcopy(self, table, include, exclude, autobatch, executor) - num = bcp.transfer(src, where, params) - return num + def bcp(self, src, table, where='(1=1)', params=[], include=[], exclude=[], autobatch=0, executor=executor): + """Bulkcopy of rows from a src 
database to the current database for a given table and where clause.""" + if type(src) == type(""): + src = dbexts(src, cfg=self.dbs) + bcp = self.bulkcopy(self, table, include, exclude, autobatch, executor) + num = bcp.transfer(src, where, params) + return num - def unload(self, filename, sql, delimiter=",", includeheaders=1): - """ Unloads the delimited results of the query to the file specified, optionally including headers. """ - u = Unload(self, filename, delimiter, includeheaders) - u.unload(sql) + def unload(self, filename, sql, delimiter=",", includeheaders=1): + """ Unloads the delimited results of the query to the file specified, optionally including headers. """ + u = Unload(self, filename, delimiter, includeheaders) + u.unload(sql) class Bulkcopy: - """The idea for a bcp class came from http://object-craft.com.au/projects/sybase""" - def __init__(self, dst, table, include=[], exclude=[], autobatch=0, executor=executor): - self.dst = dst - self.table = table - self.total = 0 - self.rows = [] - self.autobatch = autobatch - self.bindings = {} + """The idea for a bcp class came from http://object-craft.com.au/projects/sybase""" + def __init__(self, dst, table, include=[], exclude=[], autobatch=0, executor=executor): + self.dst = dst + self.table = table + self.total = 0 + self.rows = [] + self.autobatch = autobatch + self.bindings = {} - include = map(lambda x: x.lower(), include) - exclude = map(lambda x: x.lower(), exclude) + include = map(lambda x: x.lower(), include) + exclude = map(lambda x: x.lower(), exclude) - _verbose = self.dst.verbose - self.dst.verbose = 0 - try: - self.dst.table(self.table) - if self.dst.results: - colmap = {} - for a in self.dst.results: - colmap[a[3].lower()] = a[4] - cols = self.__filter__(colmap.keys(), include, exclude) - for a in zip(range(len(cols)), cols): - self.bindings[a[0]] = colmap[a[1]] - colmap = None - else: - cols = self.__filter__(include, include, exclude) - finally: - self.dst.verbose = _verbose + _verbose = 
self.dst.verbose + self.dst.verbose = 0 + try: + self.dst.table(self.table) + if self.dst.results: + colmap = {} + for a in self.dst.results: + colmap[a[3].lower()] = a[4] + cols = self.__filter__(colmap.keys(), include, exclude) + for a in zip(range(len(cols)), cols): + self.bindings[a[0]] = colmap[a[1]] + colmap = None + else: + cols = self.__filter__(include, include, exclude) + finally: + self.dst.verbose = _verbose - self.executor = executor(table, cols) + self.executor = executor(table, cols) - def __str__(self): - return "[%s].[%s]" % (self.dst, self.table) + def __str__(self): + return "[%s].[%s]" % (self.dst, self.table) - def __repr__(self): - return "[%s].[%s]" % (self.dst, self.table) + def __repr__(self): + return "[%s].[%s]" % (self.dst, self.table) - def __getattr__(self, name): - if name == 'columns': - return self.executor.cols + def __getattr__(self, name): + if name == 'columns': + return self.executor.cols - def __filter__(self, values, include, exclude): - cols = map(lambda col: col.lower(), values) - if exclude: - cols = filter(lambda x, ex=exclude: x not in ex, cols) - if include: - cols = filter(lambda x, inc=include: x in inc, cols) - return cols + def __filter__(self, values, include, exclude): + cols = map(lambda col: col.lower(), values) + if exclude: + cols = filter(lambda x, ex=exclude: x not in ex, cols) + if include: + cols = filter(lambda x, inc=include: x in inc, cols) + return cols - def format(self, column, type): - self.bindings[column] = type + def format(self, column, type): + self.bindings[column] = type - def done(self): - if len(self.rows) > 0: - return self.batch() - return 0 + def done(self): + if len(self.rows) > 0: + return self.batch() + return 0 - def batch(self): - self.executor.execute(self.dst, self.rows, self.bindings) - cnt = len(self.rows) - self.total += cnt - self.rows = [] - return cnt + def batch(self): + self.executor.execute(self.dst, self.rows, self.bindings) + cnt = len(self.rows) + self.total += cnt + 
self.rows = [] + return cnt - def rowxfer(self, line): - self.rows.append(line) - if self.autobatch: self.batch() + def rowxfer(self, line): + self.rows.append(line) + if self.autobatch: self.batch() - def transfer(self, src, where="(1=1)", params=[]): - sql = "select %s from %s where %s" % (", ".join(self.columns), self.table, where) - h, d = src.raw(sql, params) - if d: - map(self.rowxfer, d) - return self.done() - return 0 + def transfer(self, src, where="(1=1)", params=[]): + sql = "select %s from %s where %s" % (", ".join(self.columns), self.table, where) + h, d = src.raw(sql, params) + if d: + map(self.rowxfer, d) + return self.done() + return 0 class Unload: - """Unloads a sql statement to a file with optional formatting of each value.""" - def __init__(self, db, filename, delimiter=",", includeheaders=1): - self.db = db - self.filename = filename - self.delimiter = delimiter - self.includeheaders = includeheaders - self.formatters = {} + """Unloads a sql statement to a file with optional formatting of each value.""" + def __init__(self, db, filename, delimiter=",", includeheaders=1): + self.db = db + self.filename = filename + self.delimiter = delimiter + self.includeheaders = includeheaders + self.formatters = {} - def format(self, o): - if not o: - return "" - o = str(o) - if o.find(",") != -1: - o = "\"\"%s\"\"" % (o) - return o + def format(self, o): + if not o: + return "" + o = str(o) + if o.find(",") != -1: + o = "\"\"%s\"\"" % (o) + return o - def unload(self, sql, mode="w"): - headers, results = self.db.raw(sql) - w = open(self.filename, mode) - if self.includeheaders: - w.write("%s\n" % (self.delimiter.join(map(lambda x: x[0], headers)))) - if results: - for a in results: - w.write("%s\n" % (self.delimiter.join(map(self.format, a)))) - w.flush() - w.close() + def unload(self, sql, mode="w"): + headers, results = self.db.raw(sql) + w = open(self.filename, mode) + if self.includeheaders: + w.write("%s\n" % (self.delimiter.join(map(lambda x: x[0], 
headers)))) + if results: + for a in results: + w.write("%s\n" % (self.delimiter.join(map(self.format, a)))) + w.flush() + w.close() class Schema: - """Produces a Schema object which represents the database schema for a table""" - def __init__(self, db, table, owner=None, full=0, sort=1): - self.db = db - self.table = table - self.owner = owner - self.full = full - self.sort = sort - _verbose = self.db.verbose - self.db.verbose = 0 - try: - if table: self.computeschema() - finally: - self.db.verbose = _verbose + """Produces a Schema object which represents the database schema for a table""" + def __init__(self, db, table, owner=None, full=0, sort=1): + self.db = db + self.table = table + self.owner = owner + self.full = full + self.sort = sort + _verbose = self.db.verbose + self.db.verbose = 0 + try: + if table: self.computeschema() + finally: + self.db.verbose = _verbose - def computeschema(self): - self.db.table(self.table, owner=self.owner) - self.columns = [] - # (column name, type_name, size, nullable) - if self.db.results: - self.columns = map(lambda x: (x[3], x[5], x[6], x[10]), self.db.results) - if self.sort: self.columns.sort(lambda x, y: cmp(x[0], y[0])) + def computeschema(self): + self.db.table(self.table, owner=self.owner) + self.columns = [] + # (column name, type_name, size, nullable) + if self.db.results: + self.columns = map(lambda x: (x[3], x[5], x[6], x[10]), self.db.results) + if self.sort: self.columns.sort(lambda x, y: cmp(x[0], y[0])) - self.db.fk(None, self.table) - # (pk table name, pk column name, fk column name, fk name, pk name) - self.imported = [] - if self.db.results: - self.imported = map(lambda x: (x[2], x[3], x[7], x[11], x[12]), self.db.results) - if self.sort: self.imported.sort(lambda x, y: cmp(x[2], y[2])) + self.db.fk(None, self.table) + # (pk table name, pk column name, fk column name, fk name, pk name) + self.imported = [] + if self.db.results: + self.imported = map(lambda x: (x[2], x[3], x[7], x[11], x[12]), 
self.db.results) + if self.sort: self.imported.sort(lambda x, y: cmp(x[2], y[2])) - self.exported = [] - if self.full: - self.db.fk(self.table, None) - # (pk column name, fk table name, fk column name, fk name, pk name) - if self.db.results: - self.exported = map(lambda x: (x[3], x[6], x[7], x[11], x[12]), self.db.results) - if self.sort: self.exported.sort(lambda x, y: cmp(x[1], y[1])) + self.exported = [] + if self.full: + self.db.fk(self.table, None) + # (pk column name, fk table name, fk column name, fk name, pk name) + if self.db.results: + self.exported = map(lambda x: (x[3], x[6], x[7], x[11], x[12]), self.db.results) + if self.sort: self.exported.sort(lambda x, y: cmp(x[1], y[1])) - self.db.pk(self.table) - self.primarykeys = [] - if self.db.results: - # (column name, key_seq, pk name) - self.primarykeys = map(lambda x: (x[3], x[4], x[5]), self.db.results) - if self.sort: self.primarykeys.sort(lambda x, y: cmp(x[1], y[1])) + self.db.pk(self.table) + self.primarykeys = [] + if self.db.results: + # (column name, key_seq, pk name) + self.primarykeys = map(lambda x: (x[3], x[4], x[5]), self.db.results) + if self.sort: self.primarykeys.sort(lambda x, y: cmp(x[1], y[1])) - try: - self.indices = None - self.db.stat(self.table) - self.indices = [] - # (non-unique, name, type, pos, column name, asc) - if self.db.results: - idxdict = {} - # mxODBC returns a row of None's, so filter it out - idx = map(lambda x: (x[3], x[5].strip(), x[6], x[7], x[8]), filter(lambda x: x[5], self.db.results)) - def cckmp(x, y): - c = cmp(x[1], y[1]) - if c == 0: c = cmp(x[3], y[3]) - return c - # sort this regardless, this gets the indicies lined up - idx.sort(cckmp) - for a in idx: - if not idxdict.has_key(a[1]): - idxdict[a[1]] = [] - idxdict[a[1]].append(a) - self.indices = idxdict.values() - if self.sort: self.indices.sort(lambda x, y: cmp(x[0][1], y[0][1])) - except: - pass + try: + self.indices = None + self.db.stat(self.table) + self.indices = [] + # (non-unique, name, type, pos, 
column name, asc) + if self.db.results: + idxdict = {} + # mxODBC returns a row of None's, so filter it out + idx = map(lambda x: (x[3], x[5].strip(), x[6], x[7], x[8]), filter(lambda x: x[5], self.db.results)) + def cckmp(x, y): + c = cmp(x[1], y[1]) + if c == 0: c = cmp(x[3], y[3]) + return c + # sort this regardless, this gets the indicies lined up + idx.sort(cckmp) + for a in idx: + if not idxdict.has_key(a[1]): + idxdict[a[1]] = [] + idxdict[a[1]].append(a) + self.indices = idxdict.values() + if self.sort: self.indices.sort(lambda x, y: cmp(x[0][1], y[0][1])) + except: + pass - def __str__(self): - d = [] - d.append("Table") - d.append(" " + self.table) - d.append("\nPrimary Keys") - for a in self.primarykeys: - d.append(" %s {%s}" % (a[0], a[2])) - d.append("\nImported (Foreign) Keys") - for a in self.imported: - d.append(" %s (%s.%s) {%s}" % (a[2], a[0], a[1], a[3])) - if self.full: - d.append("\nExported (Referenced) Keys") - for a in self.exported: - d.append(" %s (%s.%s) {%s}" % (a[0], a[1], a[2], a[3])) - d.append("\nColumns") - for a in self.columns: - nullable = choose(a[3], "nullable", "non-nullable") - d.append(" %-20s %s(%s), %s" % (a[0], a[1], a[2], nullable)) - d.append("\nIndices") - if self.indices is None: - d.append(" (failed)") - else: - for a in self.indices: - unique = choose(a[0][0], "non-unique", "unique") - cname = ", ".join(map(lambda x: x[4], a)) - d.append(" %s index {%s} on (%s)" % (unique, a[0][1], cname)) - return "\n".join(d) + def __str__(self): + d = [] + d.append("Table") + d.append(" " + self.table) + d.append("\nPrimary Keys") + for a in self.primarykeys: + d.append(" %s {%s}" % (a[0], a[2])) + d.append("\nImported (Foreign) Keys") + for a in self.imported: + d.append(" %s (%s.%s) {%s}" % (a[2], a[0], a[1], a[3])) + if self.full: + d.append("\nExported (Referenced) Keys") + for a in self.exported: + d.append(" %s (%s.%s) {%s}" % (a[0], a[1], a[2], a[3])) + d.append("\nColumns") + for a in self.columns: + nullable = 
choose(a[3], "nullable", "non-nullable") + d.append(" %-20s %s(%s), %s" % (a[0], a[1], a[2], nullable)) + d.append("\nIndices") + if self.indices is None: + d.append(" (failed)") + else: + for a in self.indices: + unique = choose(a[0][0], "non-unique", "unique") + cname = ", ".join(map(lambda x: x[4], a)) + d.append(" %s index {%s} on (%s)" % (unique, a[0][1], cname)) + return "\n".join(d) class IniParser: - def __init__(self, cfg, key='name'): - self.key = key - self.records = {} - self.ctypeRE = re.compile("\[(jdbc|odbc|default)\]") - self.entryRE = re.compile("([a-zA-Z]+)[ \t]*=[ \t]*(.*)") - self.cfg = cfg - self.parse() + def __init__(self, cfg, key='name'): + self.key = key + self.records = {} + self.ctypeRE = re.compile("\[(jdbc|odbc|default)\]") + self.entryRE = re.compile("([a-zA-Z]+)[ \t]*=[ \t]*(.*)") + self.cfg = cfg + self.parse() - def parse(self): - fp = open(self.cfg, "r") - data = fp.readlines() - fp.close() - lines = filter(lambda x: len(x) > 0 and x[0] not in ['#', ';'], map(lambda x: x.strip(), data)) - current = None - for i in range(len(lines)): - line = lines[i] - g = self.ctypeRE.match(line) - if g: # a section header - current = {} - if not self.records.has_key(g.group(1)): - self.records[g.group(1)] = [] - self.records[g.group(1)].append(current) - else: - g = self.entryRE.match(line) - if g: - current[g.group(1)] = g.group(2) + def parse(self): + fp = open(self.cfg, "r") + data = fp.readlines() + fp.close() + lines = filter(lambda x: len(x) > 0 and x[0] not in ['#', ';'], map(lambda x: x.strip(), data)) + current = None + for i in range(len(lines)): + line = lines[i] + g = self.ctypeRE.match(line) + if g: # a section header + current = {} + if not self.records.has_key(g.group(1)): + self.records[g.group(1)] = [] + self.records[g.group(1)].append(current) + else: + g = self.entryRE.match(line) + if g: + current[g.group(1)] = g.group(2) - def __getitem__(self, (ctype, skey)): - if skey == self.key: return self.records[ctype][0][skey] - t = 
filter(lambda x, p=self.key, s=skey: x[p] == s, self.records[ctype]) - if not t or len(t) > 1: - raise KeyError, "invalid key ('%s', '%s')" % (ctype, skey) - return t[0] + def __getitem__(self, (ctype, skey)): + if skey == self.key: return self.records[ctype][0][skey] + t = filter(lambda x, p=self.key, s=skey: x[p] == s, self.records[ctype]) + if not t or len(t) > 1: + raise KeyError, "invalid key ('%s', '%s')" % (ctype, skey) + return t[0] def random_table_name(prefix, num_chars): - import random - d = [prefix, '_'] - i = 0 - while i < num_chars: - d.append(chr(int(100 * random.random()) % 26 + ord('A'))) - i += 1 - return "".join(d) + import random + d = [prefix, '_'] + i = 0 + while i < num_chars: + d.append(chr(int(100 * random.random()) % 26 + ord('A'))) + i += 1 + return "".join(d) class ResultSetRow: - def __init__(self, rs, row): - self.row = row - self.rs = rs - def __getitem__(self, i): - if type(i) == type(""): - i = self.rs.index(i) - return self.row[i] - def __getslice__(self, i, j): - if type(i) == type(""): i = self.rs.index(i) - if type(j) == type(""): j = self.rs.index(j) - return self.row[i:j] - def __len__(self): - return len(self.row) - def __repr__(self): - return str(self.row) + def __init__(self, rs, row): + self.row = row + self.rs = rs + def __getitem__(self, i): + if type(i) == type(""): + i = self.rs.index(i) + return self.row[i] + def __getslice__(self, i, j): + if type(i) == type(""): i = self.rs.index(i) + if type(j) == type(""): j = self.rs.index(j) + return self.row[i:j] + def __len__(self): + return len(self.row) + def __repr__(self): + return str(self.row) class ResultSet: - def __init__(self, headers, results=[]): - self.headers = map(lambda x: x.upper(), headers) - self.results = results - def index(self, i): - return self.headers.index(i.upper()) - def __getitem__(self, i): - return ResultSetRow(self, self.results[i]) - def __getslice__(self, i, j): - return map(lambda x, rs=self: ResultSetRow(rs, x), self.results[i:j]) - def 
__repr__(self): - return "<%s instance {cols [%d], rows [%d]} at %s>" % (self.__class__, len(self.headers), len(self.results), id(self)) + def __init__(self, headers, results=[]): + self.headers = map(lambda x: x.upper(), headers) + self.results = results + def index(self, i): + return self.headers.index(i.upper()) + def __getitem__(self, i): + return ResultSetRow(self, self.results[i]) + def __getslice__(self, i, j): + return map(lambda x, rs=self: ResultSetRow(rs, x), self.results[i:j]) + def __repr__(self): + return "<%s instance {cols [%d], rows [%d]} at %s>" % (self.__class__, len(self.headers), len(self.results), id(self)) Modified: trunk/jython/Lib/distutils/sysconfig.py =================================================================== --- trunk/jython/Lib/distutils/sysconfig.py 2008-02-27 12:52:31 UTC (rev 4184) +++ trunk/jython/Lib/distutils/sysconfig.py 2008-02-28 16:55:33 UTC (rev 4185) @@ -160,7 +160,7 @@ if os.environ.has_key('LDFLAGS'): ldshared = ldshared + ' ' + os.environ['LDFLAGS'] if basecflags: - opt = basecflags + ' ' + opt + opt = basecflags + ' ' + opt if os.environ.has_key('CFLAGS'): opt = opt + ' ' + os.environ['CFLAGS'] ldshared = ldshared + ' ' + os.environ['CFLAGS'] Modified: trunk/jython/Lib/isql.py =================================================================== --- trunk/jython/Lib/isql.py 2008-02-27 12:52:31 UTC (rev 4184) +++ trunk/jython/Lib/isql.py 2008-02-28 16:55:33 UTC (rev 4185) @@ -12,227 +12,227 @@ class IsqlExit(Exception): pass class Prompt: - """ - This class fixes a problem with the cmd.Cmd class since it uses an ivar 'prompt' - as opposed to a method 'prompt()'. To get around this, this class is plugged in - as a 'prompt' attribute and when invoked the '__str__' method is called which - figures out the appropriate prompt to display. I still think, even though this - is clever, the attribute version of 'prompt' is poor design. 
- """ - def __init__(self, isql): - self.isql = isql - def __str__(self): - prompt = "%s> " % (self.isql.db.dbname) - if len(self.isql.sqlbuffer) > 0: - prompt = "... " - return prompt - if os.name == 'java': - def __tojava__(self, cls): - import java.lang.String - if cls == java.lang.String: - return self.__str__() - return False + """ + This class fixes a problem with the cmd.Cmd class since it uses an ivar 'prompt' + as opposed to a method 'prompt()'. To get around this, this class is plugged in + as a 'prompt' attribute and when invoked the '__str__' method is called which + figures out the appropriate prompt to display. I still think, even though this + is clever, the attribute version of 'prompt' is poor design. + """ + def __init__(self, isql): + self.isql = isql + def __str__(self): + prompt = "%s> " % (self.isql.db.dbname) + if len(self.isql.sqlbuffer) > 0: + prompt = "... " + return prompt + if os.name == 'java': + def __tojava__(self, cls): + import java.lang.String + if cls == java.lang.String: + return self.__str__() + return False class IsqlCmd(cmd.Cmd): - def __init__(self, db=None, delimiter=";", comment=('#', '--')): - cmd.Cmd.__init__(self, completekey=None) - if db is None or type(db) == type(""): - self.db = dbexts.dbexts(db) - else: - self.db = db - self.kw = {} - self.sqlbuffer = [] - self.comment = comment - self.delimiter = delimiter - self.prompt = Prompt(self) - - def parseline(self, line): - command, arg, line = cmd.Cmd.parseline(self, line) - if command and command <> "EOF": - command = command.lower() - return command, arg, line + def __init__(self, db=None, delimiter=";", comment=('#', '--')): + cmd.Cmd.__init__(self, completekey=None) + if db is None or type(db) == type(""): + self.db = dbexts.dbexts(db) + else: + self.db = db + self.kw = {} + self.sqlbuffer = [] + self.comment = comment + self.delimiter = delimiter + self.prompt = Prompt(self) + + def parseline(self, line): + command, arg, line = cmd.Cmd.parseline(self, line) + if 
command and command <> "EOF": + command = command.lower() + return command, arg, line - def do_which(self, arg): - """\nPrints the current db connection parameters.\n""" - print self.db - return False + def do_which(self, arg): + """\nPrints the current db connection parameters.\n""" + print self.db + return False - def do_EOF(self, arg): - return False + def do_EOF(self, arg): + return False - def do_p(self, arg): - """\nExecute a python expression.\n""" - try: - exec arg.strip() in globals() - except: - print sys.exc_info()[1] - return False - - def do_column(self, arg): - """\nInstructions for column display.\n""" - return False - - def do_use(self, arg): - """\nUse a new database connection.\n""" - # this allows custom dbexts - self.db = self.db.__class__(arg.strip()) - return False + def do_p(self, arg): + """\nExecute a python expression.\n""" + try: + exec arg.strip() in globals() + except: + print sys.exc_info()[1] + return False + + def do_column(self, arg): + """\nInstructions for column display.\n""" + return False + + def do_use(self, arg): + """\nUse a new database connection.\n""" + # this allows custom dbexts + self.db = self.db.__class__(arg.strip()) + return False - def do_table(self, arg): - """\nPrints table meta-data. If no table name, prints all tables.\n""" - if len(arg.strip()): - self.db.table(arg, **self.kw) - else: - self.db.table(None, **self.kw) - return False + def do_table(self, arg): + """\nPrints table meta-data. 
If no table name, prints all tables.\n""" + if len(arg.strip()): + self.db.table(arg, **self.kw) + else: + self.db.table(None, **self.kw) + return False - def do_proc(self, arg): - """\nPrints store procedure meta-data.\n""" - if len(arg.strip()): - self.db.proc(arg, **self.kw) - else: - self.db.proc(None, **self.kw) - return False + def do_proc(self, arg): + """\nPrints store procedure meta-data.\n""" + if len(arg.strip()): + self.db.proc(arg, **self.kw) + else: + self.db.proc(None, **self.kw) + return False - def do_schema(self, arg): - """\nPrints schema information.\n""" - print - self.db.schema(arg) - print - return False + def do_schema(self, arg): + """\nPrints schema information.\n""" + print + self.db.schema(arg) + print + return False - def do_delimiter(self, arg): - """\nChange the delimiter.\n""" - delimiter = arg.strip() - if len(delimiter) > 0: - self.delimiter = delimiter - - def do_o(self, arg): - """\nSet the output.\n""" - if not arg: - fp = self.db.out - try: - if fp: - fp.close() - finally: - self.db.out = None - else: - fp = open(arg, "w") - self.db.out = fp + def do_delimiter(self, arg): + """\nChange the delimiter.\n""" + delimiter = arg.strip() + if len(delimiter) > 0: + self.delimiter = delimiter + + def do_o(self, arg): + """\nSet the output.\n""" + if not arg: + fp = self.db.out + try: + ... [truncated message content] |