Diff of /MySQLdb/cursors.py [24fa6a] .. [80164e] Maximize Restore

  Switch to side-by-side view

--- a/MySQLdb/cursors.py
+++ b/MySQLdb/cursors.py
@@ -13,6 +13,7 @@
 import re
 import sys
 import weakref
+from MySQLdb.converters import get_codec, tuple_row_decoder
 
 INSERT_VALUES = re.compile(r"(?P<start>.+values\s*)"
                            r"(?P<values>\(((?<!\\)'[^\)]*?\)[^\)]*(?<!\\)?'|[^\(\)]|(?:\([^\)]*\)))+\))"
@@ -39,8 +40,7 @@
     _defer_warnings = False
     _fetch_type = None
 
-    def __init__(self, connection, encoders):
-        from MySQLdb.converters import default_decoders
+    def __init__(self, connection, encoders, decoders):
         self.connection = weakref.proxy(connection)
         self.description = None
         self.description_flags = None
@@ -54,17 +54,41 @@
         self._warnings = 0
         self._info = None
         self.rownumber = None
-        self._encoders = encoders
-
+        self.maxrows = 0
+        self.encoders = encoders
+        self.decoders = decoders
+        self._row_decoders = ()
+        self.row_decoder = tuple_row_decoder
+
+    def _flush(self):
+        """_flush() reads to the end of the current result set, buffering what
+        it can, and then releases the result set."""
+        if self._result:
+            for row in self._result:
+                pass
+            self._result = None
+    
     def __del__(self):
         self.close()
         self.errorhandler = None
         self._result = None
 
+    def _reset(self):
+        while True:
+            if self._result:
+                for row in self._result:
+                    pass
+                self._result = None
+            if not self.nextset():
+                break
+        del self.messages[:]
+            
     def close(self):
         """Close the cursor. No further queries will be possible."""
         if not self.connection:
             return
+        
+        self._flush()
         try:
             while self.nextset():
                 pass
@@ -106,22 +130,21 @@
         num_rows = connection.next_result()
         if num_rows == -1:
             return None
-        self._do_get_result()
-        self._post_get_result()
-        self._warning_check()
-        return True
-
-    def _do_get_result(self):
-        """Get the result from the last query."""
-        connection = self._get_db()
-        self._result = self._get_result()
-        self.rowcount = connection.affected_rows()
+        result = connection.use_result()
+        self._result = result
+        if result:
+            self.field_flags = result.field_flags()
+            self._row_decoders = [ get_codec(field, self.decoders) for field in result.fields ]
+            self.description = result.describe()
+        else:
+            self._row_decoders = self.field_flags = ()
+            self.description = None
+        self.rowcount = -1 #connection.affected_rows()
         self.rownumber = 0
-        self.description = self._result and self._result.describe() or None
-        self.description_flags = self._result and self._result.field_flags() or None
         self.lastrowid = connection.insert_id()
         self._warnings = connection.warning_count()
-        self._info = connection.info()
+        self._info = connection.info()        
+        return True
     
     def setinputsizes(self, *args):
         """Does nothing, required by DB API."""
@@ -150,15 +173,15 @@
         Returns long integer rows affected, if any
 
         """
-        del self.messages[:]
         db = self._get_db()
+        self._reset()
         charset = db.character_set_name()
         if isinstance(query, unicode):
             query = query.encode(charset)
         try:
             if args is not None:
                 query = query % tuple(map(self.connection.literal, args))
-            result = self._query(query)
+            self._query(query)
         except TypeError, msg:
             if msg.args[0] in ("not enough arguments for format string",
                                "not all arguments converted"):
@@ -173,10 +196,9 @@
             self.messages.append((exc, value))
             self.errorhandler(self, exc, value)
             
-        self._executed = query
         if not self._defer_warnings:
             self._warning_check()
-        return result
+        return None
 
     def executemany(self, query, args):
         """Execute a multi-row query.
@@ -197,8 +219,8 @@
         execute().
 
         """
-        del self.messages[:]
         db = self._get_db()
+        self._reset()
         if not args:
             return
         charset = self.connection.character_set_name()
@@ -216,8 +238,7 @@
         try:
             sql_params = ( values % tuple(map(self.connection.literal, row)) for row in args )
             multirow_query = '\n'.join([start, ',\n'.join(sql_params), end])
-            self._executed = multirow_query
-            self.rowcount = int(self._query(multirow_query))
+            self._query(multirow_query)
 
         except TypeError, msg:
             if msg.args[0] in ("not enough arguments for format string",
@@ -234,7 +255,7 @@
         
         if not self._defer_warnings:
             self._warning_check()
-        return self.rowcount
+        return None
     
     def callproc(self, procname, args=()):
         """Execute stored procedure procname with args
@@ -283,71 +304,62 @@
         if isinstance(query, unicode):
             query = query.encode(charset)
         self._query(query)
-        self._executed = query
         if not self._defer_warnings:
             self._warning_check()
         return args
-    
-    def _do_query(self, query):
-        """Low-levey query wrapper. Overridden by MixIns."""
+
+    def __iter__(self):
+        return iter(self.fetchone, None)
+
+    def _query(self, query):
+        """Low-level; executes query, gets result, sets up decoders."""
         connection = self._get_db()
+        self._flush()
         self._executed = query
         connection.query(query)
-        self._do_get_result()
-        return self.rowcount
-
-    def _fetch_row(self, size=1):
-        """Low-level fetch_row wrapper."""
-        if not self._result:
-            return ()
-        return self._result.fetch_row(size, self._fetch_type)
-
-    def __iter__(self):
-        return iter(self.fetchone, None)
-
-    def _get_result(self):
-        """Low-level; uses mysql_store_result()"""
-        return self._get_db().store_result()
-
-    def _query(self, query):
-        """Low-level; executes query, gets result, and returns rowcount."""
-        rowcount = self._do_query(query)
-        self._post_get_result()
-        return rowcount
-
-    def _post_get_result(self):
-        """Low-level"""
-        self._rows = self._fetch_row(0)
-        self._result = None
-
+        result = connection.use_result()
+        self._result = result
+        if result:
+            self.field_flags = result.field_flags()
+            self._row_decoders = [ get_codec(field, self.decoders) for field in result.fields ]
+            self.description = result.describe()
+        else:
+            self._row_decoders = self.field_flags = ()
+            self.description = None
+        self.rowcount = -1 #connection.affected_rows()
+        self.rownumber = 0
+        self.lastrowid = connection.insert_id()
+        self._warnings = connection.warning_count()
+        self._info = connection.info()
+    
     def fetchone(self):
         """Fetches a single row from the cursor. None indicates that
         no more rows are available."""
         self._check_executed()
-        if self.rownumber >= len(self._rows):
-            return None
-        result = self._rows[self.rownumber]
-        self.rownumber += 1
-        return result
+        row = self.row_decoder(self._row_decoders, self._result.simple_fetch_row())
+        return row
 
     def fetchmany(self, size=None):
         """Fetch up to size rows from the cursor. Result set may be smaller
         than size. If size is not defined, cursor.arraysize is used."""
         self._check_executed()
-        end = self.rownumber + (size or self.arraysize)
-        result = self._rows[self.rownumber:end]
-        self.rownumber = min(end, len(self._rows))
-        return result
+        if size is None:
+            size = self.arraysize
+        rows = []
+        for i in range(size):
+            row = self.row_decoder(self._row_decoders, self._result.simple_fetch_row())
+            if row is None: break
+            rows.append(row)
+        return rows
 
     def fetchall(self):
-        """Fetchs all available rows from the cursor."""
+        """Fetches all available rows from the cursor."""
         self._check_executed()
-        if self.rownumber:
-            result = self._rows[self.rownumber:]
+        if self._result:
+            rows = [ self.row_decoder(self._row_decoders, row) for row in self._result ]
         else:
-            result = self._rows
-        self.rownumber = len(self._rows)
-        return result
+            rows = []
+        return rows
     
     def scroll(self, value, mode='relative'):
         """Scroll the cursor in the result set to a new position according
@@ -368,9 +380,3 @@
             self.errorhandler(self, IndexError, "out of range")
         self.rownumber = row
 
-    def __iter__(self):
-        self._check_executed()
-        result = self.rownumber and self._rows[self.rownumber:] or self._rows
-        return iter(result)
-    
-    _fetch_type = 0