Fix cursor.scroll() to calculate row count correctly in all situations and

add test cases to verify that it does in fact work.
This commit is contained in:
Anthony Tuininga 2017-01-12 15:22:29 -07:00
parent 0e50be2a08
commit ab42359f26
2 changed files with 70 additions and 34 deletions

View File

@ -26,6 +26,7 @@ typedef struct {
int outputSize;
int outputSizeColumn;
ub8 rowCount;
ub8 bufferMinRow;
ub4 bufferRowCount;
ub4 bufferRowIndex;
int statementType;
@ -552,8 +553,9 @@ static int Cursor_SetRowCount(
self->rowCount = 0;
self->hasRowsToFetch = 0;
if (self->statementType == OCI_STMT_SELECT) {
self->bufferMinRow = 0;
self->bufferRowCount = 0;
self->bufferRowIndex = self->fetchArraySize;
self->bufferRowIndex = 0;
self->hasRowsToFetch = 1;
} else if (self->statementType == OCI_STMT_INSERT ||
self->statementType == OCI_STMT_UPDATE ||
@ -1913,7 +1915,8 @@ static int Cursor_InternalFetch(
"Cursor_InternalFetch(): get rows fetched") < 0)
return -1;
// reset buffer row index
// set buffer row info
self->bufferMinRow = self->rowCount + 1;
self->bufferRowIndex = 0;
return 0;
@ -2099,8 +2102,8 @@ static PyObject *Cursor_Scroll(
PyObject *keywordArgs) // keyword arguments
{
static char *keywordList[] = { "value", "mode", NULL };
ub8 desiredRow, minRowInBuffers, maxRowInBuffers;
ub4 fetchMode, numRows, currentPosition;
ub8 desiredRow;
sword status;
char *mode;
int value;
@ -2113,12 +2116,10 @@ static PyObject *Cursor_Scroll(
return NULL;
// validate mode
if (!mode) {
fetchMode = OCI_FETCH_RELATIVE;
desiredRow = self->rowCount + value;
} else if (strcmp(mode, "relative") == 0) {
if (!mode || strcmp(mode, "relative") == 0) {
fetchMode = OCI_FETCH_RELATIVE;
desiredRow = self->rowCount + value;
value = desiredRow - (self->bufferMinRow + self->bufferRowCount - 1);
} else if (strcmp(mode, "absolute") == 0) {
fetchMode = OCI_FETCH_ABSOLUTE;
desiredRow = value;
@ -2139,19 +2140,11 @@ static PyObject *Cursor_Scroll(
return NULL;
// determine if a fetch is actually required; "last" is always fetched
if (fetchMode != OCI_FETCH_LAST && self->bufferRowCount > 0) {
minRowInBuffers = self->rowCount - self->bufferRowIndex;
maxRowInBuffers = self->rowCount + self->bufferRowCount -
self->bufferRowIndex - 1;
if (self->bufferRowIndex == self->bufferRowCount) {
minRowInBuffers += 1;
maxRowInBuffers += 1;
}
if (desiredRow >= minRowInBuffers && desiredRow <= maxRowInBuffers) {
self->bufferRowIndex = desiredRow - minRowInBuffers;
self->rowCount = desiredRow - 1;
Py_RETURN_NONE;
}
if (fetchMode != OCI_FETCH_LAST && desiredRow >= self->bufferMinRow &&
desiredRow < self->bufferMinRow + self->bufferRowCount) {
self->bufferRowIndex = desiredRow - self->bufferMinRow;
self->rowCount = desiredRow - 1;
Py_RETURN_NONE;
}
// perform fetch; when fetching the last row, only fetch a single row
@ -2160,21 +2153,12 @@ static PyObject *Cursor_Scroll(
status = OCIStmtFetch2(self->handle, self->environment->errorHandle,
numRows, fetchMode, value, OCI_DEFAULT);
Py_END_ALLOW_THREADS
if (status == OCI_NO_DATA) {
if (fetchMode == OCI_FETCH_FIRST || fetchMode == OCI_FETCH_LAST) {
self->hasRowsToFetch = 0;
self->rowCount = 0;
self->bufferRowCount = 0;
self->bufferRowIndex = 0;
} else {
PyErr_SetString(PyExc_IndexError,
"requested scroll operation would leave result set");
return NULL;
}
} else if (Environment_CheckForError(self->environment, status,
if (status == OCI_NO_DATA || fetchMode == OCI_FETCH_LAST)
self->hasRowsToFetch = 0;
else if (Environment_CheckForError(self->environment, status,
"Cursor_Scroll(): fetch") < 0)
return NULL;
self->hasRowsToFetch = 1;
else self->hasRowsToFetch = 1;
// determine the number of rows actually fetched
status = OCIAttrGet(self->handle, OCI_HTYPE_STMT, &self->bufferRowCount, 0,
@ -2183,6 +2167,20 @@ static PyObject *Cursor_Scroll(
"Cursor_Scroll(): get rows fetched") < 0)
return NULL;
// handle the case when no rows have been retrieved
if (self->bufferRowCount == 0) {
if (fetchMode != OCI_FETCH_FIRST && fetchMode != OCI_FETCH_LAST) {
PyErr_SetString(PyExc_IndexError,
"requested scroll operation would leave result set");
return NULL;
}
self->rowCount = 0;
self->bufferMinRow = 0;
self->bufferRowCount = 0;
self->bufferRowIndex = 0;
Py_RETURN_NONE;
}
// determine the current position of the cursor
status = OCIAttrGet(self->handle, OCI_HTYPE_STMT, &currentPosition, 0,
OCI_ATTR_CURRENT_POSITION, self->environment->errorHandle);
@ -2192,8 +2190,8 @@ static PyObject *Cursor_Scroll(
// reset buffer row index and row count
self->rowCount = currentPosition - self->bufferRowCount;
self->bufferMinRow = self->rowCount + 1;
self->bufferRowIndex = 0;
Py_RETURN_NONE;
}

View File

@ -381,6 +381,44 @@ class TestCursor(BaseTestCase):
self.assertEqual(row[0], 3.75)
self.assertEqual(cursor.rowcount, 3)
def testScrollNoRows(self):
"""test scrolling when there are no rows"""
self.cursor.execute("truncate table TestTempTable")
cursor = self.connection.cursor(scrollable = True)
cursor.execute("select * from TestTempTable")
cursor.scroll(mode = "last")
self.assertEqual(cursor.fetchall(), [])
cursor.scroll(mode = "first")
self.assertEqual(cursor.fetchall(), [])
self.assertRaises(IndexError, cursor.scroll, 1, mode = "absolute")
def testScrollDifferingArrayAndFetchSizes(self):
"""test scrolling with differing array sizes and fetch array sizes"""
self.cursor.execute("truncate table TestTempTable")
for i in range(30):
self.cursor.execute("insert into TestTempTable values (:1, null)",
(i + 1,))
for arraySize in range(1, 6):
cursor = self.connection.cursor(scrollable = True)
cursor.arraysize = arraySize
cursor.execute("select IntCol from TestTempTable order by IntCol")
for numRows in range(1, arraySize + 1):
cursor.scroll(15, "absolute")
rows = cursor.fetchmany(numRows)
self.assertEqual(rows[0][0], 15)
self.assertEqual(cursor.rowcount, 15 + numRows - 1)
cursor.scroll(9)
rows = cursor.fetchmany(numRows)
numRowsFetched = len(rows)
self.assertEqual(rows[0][0], 15 + numRows + 8)
self.assertEqual(cursor.rowcount,
15 + numRows + numRowsFetched + 7)
cursor.scroll(-12)
rows = cursor.fetchmany(numRows)
self.assertEqual(rows[0][0], 15 + numRows + numRowsFetched - 5)
self.assertEqual(cursor.rowcount,
15 + numRows + numRowsFetched + numRows - 6)
def testSetInputSizesMultipleMethod(self):
"""test setting input sizes with both positional and keyword args"""
self.assertRaises(cx_Oracle.InterfaceError,