diff --git a/test/test_3000_subscription.py b/test/test_3000_subscription.py index 595ebe0..daa5c52 100644 --- a/test/test_3000_subscription.py +++ b/test/test_3000_subscription.py @@ -7,39 +7,60 @@ """ import threading +import unittest import cx_Oracle as oracledb import test_env -class SubscriptionData(object): +class SubscriptionData: def __init__(self, num_messages_expected): self.condition = threading.Condition() self.num_messages_expected = num_messages_expected self.num_messages_received = 0 + + def _process_message(self, message): + pass + + def callback_handler(self, message): + if message.type != oracledb.EVENT_DEREG: + self._process_message(message) + self.num_messages_received += 1 + if message.type == oracledb.EVENT_DEREG or \ + self.num_messages_received == self.num_messages_expected: + with self.condition: + self.condition.notify() + + def wait_for_messages(self): + if self.num_messages_received < self.num_messages_expected: + with self.condition: + self.condition.wait(10) + + +class AQSubscriptionData(SubscriptionData): + pass + + +class DMLSubscriptionData(SubscriptionData): + + def __init__(self, num_messages_expected): + super().__init__(num_messages_expected) self.table_operations = [] self.row_operations = [] self.rowids = [] - def CallbackHandler(self, message): - if message.type != oracledb.EVENT_DEREG: - table, = message.tables - self.table_operations.append(table.operation) - for row in table.rows: - self.row_operations.append(row.operation) - self.rowids.append(row.rowid) - self.num_messages_received += 1 - if message.type == oracledb.EVENT_DEREG or \ - self.num_messages_received == self.num_messages_expected: - self.condition.acquire() - self.condition.notify() - self.condition.release() + def _process_message(self, message): + table, = message.tables + self.table_operations.append(table.operation) + for row in table.rows: + self.row_operations.append(row.operation) + self.rowids.append(row.rowid) class TestCase(test_env.BaseTestCase): - def test_3000_subscription(self): - "3000 - test Subscription for insert, update, delete and truncate" + def test_3000_dml_subscription(self): + "3000 - test subscription for insert, update, delete and truncate" # skip if running on the Oracle Cloud, which does not support # subscriptions currently @@ -67,9 +88,9 @@ class TestCase(test_env.BaseTestCase): rowids = [] # set up subscription - data = SubscriptionData(5) + data = DMLSubscriptionData(5) connection = test_env.get_connection(threaded=True, events=True) - sub = connection.subscribe(callback=data.CallbackHandler, + sub = connection.subscribe(callback=data.callback_handler, timeout=10, qos=oracledb.SUBSCR_QOS_ROWIDS) sub.registerquery("select * from TestTempTable") connection.autocommit = True @@ -105,8 +126,7 @@ class TestCase(test_env.BaseTestCase): cursor.execute("truncate table TestTempTable") # wait for all messages to be sent - data.condition.acquire() - data.condition.wait(10) + data.wait_for_messages() # verify the correct messages were sent self.assertEqual(data.table_operations, table_operations) @@ -134,5 +154,30 @@ class TestCase(test_env.BaseTestCase): self.assertRaises(oracledb.ProgrammingError, connection.subscribe, client_initiated=True, clientInitiated=True) + @unittest.skip("multiple subscriptions cannot be created simultaneously") + def test_3002_aq_subscription(self): + "3002 - test subscription for AQ" + + # create queue and clear it of all messages + queue = self.connection.queue("TEST_RAW_QUEUE") + queue.deqoptions.wait = oracledb.DEQ_NO_WAIT + while queue.deqone(): + pass + self.connection.commit() + + # set up subscription + data = AQSubscriptionData(1) + connection = test_env.get_connection(events=True) + sub = connection.subscribe(namespace=oracledb.SUBSCR_NAMESPACE_AQ, + name=queue.name, timeout=10, + callback=data.callback_handler) + + # enqueue a message + queue.enqone(self.connection.msgproperties(payload="Some data")) + self.connection.commit() + + # wait for all messages to be sent + data.wait_for_messages() + if __name__ == "__main__": test_env.run_test_cases()