Fix cursor.scroll() to calculate row count correctly in all situations and

add test cases to verify that it does in fact work.
This commit is contained in:
Anthony Tuininga 2017-01-12 15:22:29 -07:00
parent 0e50be2a08
commit ab42359f26
2 changed files with 70 additions and 34 deletions

View File

@ -26,6 +26,7 @@ typedef struct {
int outputSize; int outputSize;
int outputSizeColumn; int outputSizeColumn;
ub8 rowCount; ub8 rowCount;
ub8 bufferMinRow;
ub4 bufferRowCount; ub4 bufferRowCount;
ub4 bufferRowIndex; ub4 bufferRowIndex;
int statementType; int statementType;
@ -552,8 +553,9 @@ static int Cursor_SetRowCount(
self->rowCount = 0; self->rowCount = 0;
self->hasRowsToFetch = 0; self->hasRowsToFetch = 0;
if (self->statementType == OCI_STMT_SELECT) { if (self->statementType == OCI_STMT_SELECT) {
self->bufferMinRow = 0;
self->bufferRowCount = 0; self->bufferRowCount = 0;
self->bufferRowIndex = self->fetchArraySize; self->bufferRowIndex = 0;
self->hasRowsToFetch = 1; self->hasRowsToFetch = 1;
} else if (self->statementType == OCI_STMT_INSERT || } else if (self->statementType == OCI_STMT_INSERT ||
self->statementType == OCI_STMT_UPDATE || self->statementType == OCI_STMT_UPDATE ||
@ -1913,7 +1915,8 @@ static int Cursor_InternalFetch(
"Cursor_InternalFetch(): get rows fetched") < 0) "Cursor_InternalFetch(): get rows fetched") < 0)
return -1; return -1;
// reset buffer row index // set buffer row info
self->bufferMinRow = self->rowCount + 1;
self->bufferRowIndex = 0; self->bufferRowIndex = 0;
return 0; return 0;
@ -2099,8 +2102,8 @@ static PyObject *Cursor_Scroll(
PyObject *keywordArgs) // keyword arguments PyObject *keywordArgs) // keyword arguments
{ {
static char *keywordList[] = { "value", "mode", NULL }; static char *keywordList[] = { "value", "mode", NULL };
ub8 desiredRow, minRowInBuffers, maxRowInBuffers;
ub4 fetchMode, numRows, currentPosition; ub4 fetchMode, numRows, currentPosition;
ub8 desiredRow;
sword status; sword status;
char *mode; char *mode;
int value; int value;
@ -2113,12 +2116,10 @@ static PyObject *Cursor_Scroll(
return NULL; return NULL;
// validate mode // validate mode
if (!mode) { if (!mode || strcmp(mode, "relative") == 0) {
fetchMode = OCI_FETCH_RELATIVE;
desiredRow = self->rowCount + value;
} else if (strcmp(mode, "relative") == 0) {
fetchMode = OCI_FETCH_RELATIVE; fetchMode = OCI_FETCH_RELATIVE;
desiredRow = self->rowCount + value; desiredRow = self->rowCount + value;
value = desiredRow - (self->bufferMinRow + self->bufferRowCount - 1);
} else if (strcmp(mode, "absolute") == 0) { } else if (strcmp(mode, "absolute") == 0) {
fetchMode = OCI_FETCH_ABSOLUTE; fetchMode = OCI_FETCH_ABSOLUTE;
desiredRow = value; desiredRow = value;
@ -2139,19 +2140,11 @@ static PyObject *Cursor_Scroll(
return NULL; return NULL;
// determine if a fetch is actually required; "last" is always fetched // determine if a fetch is actually required; "last" is always fetched
if (fetchMode != OCI_FETCH_LAST && self->bufferRowCount > 0) { if (fetchMode != OCI_FETCH_LAST && desiredRow >= self->bufferMinRow &&
minRowInBuffers = self->rowCount - self->bufferRowIndex; desiredRow < self->bufferMinRow + self->bufferRowCount) {
maxRowInBuffers = self->rowCount + self->bufferRowCount - self->bufferRowIndex = desiredRow - self->bufferMinRow;
self->bufferRowIndex - 1; self->rowCount = desiredRow - 1;
if (self->bufferRowIndex == self->bufferRowCount) { Py_RETURN_NONE;
minRowInBuffers += 1;
maxRowInBuffers += 1;
}
if (desiredRow >= minRowInBuffers && desiredRow <= maxRowInBuffers) {
self->bufferRowIndex = desiredRow - minRowInBuffers;
self->rowCount = desiredRow - 1;
Py_RETURN_NONE;
}
} }
// perform fetch; when fetching the last row, only fetch a single row // perform fetch; when fetching the last row, only fetch a single row
@ -2160,21 +2153,12 @@ static PyObject *Cursor_Scroll(
status = OCIStmtFetch2(self->handle, self->environment->errorHandle, status = OCIStmtFetch2(self->handle, self->environment->errorHandle,
numRows, fetchMode, value, OCI_DEFAULT); numRows, fetchMode, value, OCI_DEFAULT);
Py_END_ALLOW_THREADS Py_END_ALLOW_THREADS
if (status == OCI_NO_DATA) { if (status == OCI_NO_DATA || fetchMode == OCI_FETCH_LAST)
if (fetchMode == OCI_FETCH_FIRST || fetchMode == OCI_FETCH_LAST) { self->hasRowsToFetch = 0;
self->hasRowsToFetch = 0; else if (Environment_CheckForError(self->environment, status,
self->rowCount = 0;
self->bufferRowCount = 0;
self->bufferRowIndex = 0;
} else {
PyErr_SetString(PyExc_IndexError,
"requested scroll operation would leave result set");
return NULL;
}
} else if (Environment_CheckForError(self->environment, status,
"Cursor_Scroll(): fetch") < 0) "Cursor_Scroll(): fetch") < 0)
return NULL; return NULL;
self->hasRowsToFetch = 1; else self->hasRowsToFetch = 1;
// determine the number of rows actually fetched // determine the number of rows actually fetched
status = OCIAttrGet(self->handle, OCI_HTYPE_STMT, &self->bufferRowCount, 0, status = OCIAttrGet(self->handle, OCI_HTYPE_STMT, &self->bufferRowCount, 0,
@ -2183,6 +2167,20 @@ static PyObject *Cursor_Scroll(
"Cursor_Scroll(): get rows fetched") < 0) "Cursor_Scroll(): get rows fetched") < 0)
return NULL; return NULL;
// handle the case when no rows have been retrieved
if (self->bufferRowCount == 0) {
if (fetchMode != OCI_FETCH_FIRST && fetchMode != OCI_FETCH_LAST) {
PyErr_SetString(PyExc_IndexError,
"requested scroll operation would leave result set");
return NULL;
}
self->rowCount = 0;
self->bufferMinRow = 0;
self->bufferRowCount = 0;
self->bufferRowIndex = 0;
Py_RETURN_NONE;
}
// determine the current position of the cursor // determine the current position of the cursor
status = OCIAttrGet(self->handle, OCI_HTYPE_STMT, &currentPosition, 0, status = OCIAttrGet(self->handle, OCI_HTYPE_STMT, &currentPosition, 0,
OCI_ATTR_CURRENT_POSITION, self->environment->errorHandle); OCI_ATTR_CURRENT_POSITION, self->environment->errorHandle);
@ -2192,8 +2190,8 @@ static PyObject *Cursor_Scroll(
// reset buffer row index and row count // reset buffer row index and row count
self->rowCount = currentPosition - self->bufferRowCount; self->rowCount = currentPosition - self->bufferRowCount;
self->bufferMinRow = self->rowCount + 1;
self->bufferRowIndex = 0; self->bufferRowIndex = 0;
Py_RETURN_NONE; Py_RETURN_NONE;
} }

View File

@ -381,6 +381,44 @@ class TestCursor(BaseTestCase):
self.assertEqual(row[0], 3.75) self.assertEqual(row[0], 3.75)
self.assertEqual(cursor.rowcount, 3) self.assertEqual(cursor.rowcount, 3)
def testScrollNoRows(self):
"""test scrolling when there are no rows"""
self.cursor.execute("truncate table TestTempTable")
cursor = self.connection.cursor(scrollable = True)
cursor.execute("select * from TestTempTable")
cursor.scroll(mode = "last")
self.assertEqual(cursor.fetchall(), [])
cursor.scroll(mode = "first")
self.assertEqual(cursor.fetchall(), [])
self.assertRaises(IndexError, cursor.scroll, 1, mode = "absolute")
def testScrollDifferingArrayAndFetchSizes(self):
"""test scrolling with differing array sizes and fetch array sizes"""
self.cursor.execute("truncate table TestTempTable")
for i in range(30):
self.cursor.execute("insert into TestTempTable values (:1, null)",
(i + 1,))
for arraySize in range(1, 6):
cursor = self.connection.cursor(scrollable = True)
cursor.arraysize = arraySize
cursor.execute("select IntCol from TestTempTable order by IntCol")
for numRows in range(1, arraySize + 1):
cursor.scroll(15, "absolute")
rows = cursor.fetchmany(numRows)
self.assertEqual(rows[0][0], 15)
self.assertEqual(cursor.rowcount, 15 + numRows - 1)
cursor.scroll(9)
rows = cursor.fetchmany(numRows)
numRowsFetched = len(rows)
self.assertEqual(rows[0][0], 15 + numRows + 8)
self.assertEqual(cursor.rowcount,
15 + numRows + numRowsFetched + 7)
cursor.scroll(-12)
rows = cursor.fetchmany(numRows)
self.assertEqual(rows[0][0], 15 + numRows + numRowsFetched - 5)
self.assertEqual(cursor.rowcount,
15 + numRows + numRowsFetched + numRows - 6)
def testSetInputSizesMultipleMethod(self): def testSetInputSizesMultipleMethod(self):
"""test setting input sizes with both positional and keyword args""" """test setting input sizes with both positional and keyword args"""
self.assertRaises(cx_Oracle.InterfaceError, self.assertRaises(cx_Oracle.InterfaceError,