diff --git a/connectors/sources/azure_blob_storage.py b/connectors/sources/azure_blob_storage.py
index b46b4df2f..6d501baf5 100644
--- a/connectors/sources/azure_blob_storage.py
+++ b/connectors/sources/azure_blob_storage.py
@@ -185,6 +185,7 @@ async def get_content(self, blob, timestamp=None, doit=None):
         if not self.can_file_be_downloaded(file_extension, filename, file_size):
             return

+        self._logger.debug(f"Downloading content for file: {filename}")
         document = {"_id": blob["id"], "_timestamp": blob["_timestamp"]}
         return await self.download_and_extract_file(
             document,
@@ -194,6 +195,9 @@ async def get_content(self, blob, timestamp=None, doit=None):
         )

     async def blob_download_func(self, blob_name, container_name):
+        self._logger.debug(
+            f"Downloading content for blob: {blob_name} from {container_name} container"
+        )
         async with BlobClient.from_connection_string(
             conn_str=self.connection_string,
             container_name=container_name,
@@ -212,6 +216,7 @@ async def get_container(self, container_list):
         Yields:
             dictionary: Container document with name & metadata
         """
+        self._logger.debug("Fetching containers")
         container_set = set(container_list)
         async with BlobServiceClient.from_connection_string(
             conn_str=self.connection_string, retry_total=self.retry_count
@@ -247,6 +252,7 @@ async def get_blob(self, container):
         Yields:
             dictionary: Formatted blob document
         """
+        self._logger.info(f"Fetching blobs for '{container['name']}' container")
         async with ContainerClient.from_connection_string(
             conn_str=self.connection_string,
             container_name=container["name"],
diff --git a/connectors/sources/confluence.py b/connectors/sources/confluence.py
index 60830a9d2..46ec94262 100644
--- a/connectors/sources/confluence.py
+++ b/connectors/sources/confluence.py
@@ -81,6 +81,10 @@
 WILDCARD = "*"


+class InvalidConfluenceDataSourceTypeError(ValueError):
+    pass
+
+
 class ConfluenceClient:
     """Confluence client to handle API calls made to Confluence"""

@@ -114,7 +118,7 @@ def _get_session(self):
         if self.session:
             return self.session

-        self._logger.debug("Creating a client session")
+        self._logger.debug(f"Creating a '{self.data_source_type}' client session")
         if self.data_source_type == CONFLUENCE_CLOUD:
             auth = (
                 self.configuration["account_email"],
@@ -125,11 +129,16 @@ def _get_session(self):
                 self.configuration["username"],
                 self.configuration["password"],
             )
-        else:
+        elif self.data_source_type == CONFLUENCE_DATA_CENTER:
             auth = (
                 self.configuration["data_center_username"],
                 self.configuration["data_center_password"],
             )
+        else:
+            msg = f"Unknown data source type '{self.data_source_type}' for Confluence connector"
+            self._logger.error(msg)
+
+            raise InvalidConfluenceDataSourceTypeError(msg)

         basic_auth = aiohttp.BasicAuth(login=auth[0], password=auth[1])
         timeout = aiohttp.ClientTimeout(total=None)  # pyright: ignore
@@ -795,9 +804,7 @@ async def fetch_server_space_permission(self, space_key):
             return {}

         url = URLS[SPACE_PERMISSION].format(space_key=space_key)
-        self._logger.debug(
-            f"Fetching permissions for space '{space_key} from Confluence server'"
-        )
+        self._logger.info(f"Fetching permissions for '{space_key}' space")
         return await self.confluence_client.fetch_server_space_permission(url=url)

     async def fetch_documents(self, api_query):
@@ -857,7 +864,7 @@ async def fetch_attachments(
             String: Download link to get the content of the attachment
         """
         self._logger.info(
-            f"Fetching attachments for '{parent_name}' from '{parent_space}' space"
+            f"Fetching attachments for '{parent_name}' {parent_type} from '{parent_space}' space"
         )
         async for attachment in self.confluence_client.fetch_attachments(
             content_id=content_id,
@@ -941,7 +948,7 @@ async def download_attachment(self, url, attachment, timestamp=None, doit=False):
         if not self.can_file_be_downloaded(file_extension, filename, file_size):
             return

-        self._logger.info(f"Downloading content for file: {filename}")
+        self._logger.debug(f"Downloading content for file: {filename}")
         document = {"_id": attachment["_id"], "_timestamp": attachment["_timestamp"]}
         return await self.download_and_extract_file(
             document,
@@ -1123,7 +1130,9 @@ async def get_docs(self, filtering=None):
             advanced_rules = filtering.get_advanced_rules()
             for query_info in advanced_rules:
                 query = query_info.get("query")
-                logger.debug(f"Fetching confluence content using custom query: {query}")
+                self._logger.debug(
+                    f"Fetching confluence content using custom query: {query}"
+                )
                 async for document, download_link in self.search_by_query(query):
                     if download_link:
                         yield document, partial(
diff --git a/connectors/sources/mssql.py b/connectors/sources/mssql.py
index 3a7459ac1..7799587bc 100644
--- a/connectors/sources/mssql.py
+++ b/connectors/sources/mssql.py
@@ -255,6 +255,12 @@ async def ping(self):
     async def get_tables_to_fetch(self, is_filtering=False):
         tables = configured_tables(self.tables)
         if is_wildcard(tables) or is_filtering:
+            msg = (
+                "Fetching all tables as the configuration field 'tables' is set to '*'"
+                if not is_filtering
+                else "Fetching all tables as the advanced sync rules are enabled."
+            )
+            self._logger.info(msg)
             async for row in fetch(
                 cursor_func=partial(
                     self.get_cursor,
@@ -268,6 +274,7 @@ async def get_tables_to_fetch(self, is_filtering=False):
             ):
                 yield row[0]
         else:
+            self._logger.info(f"Fetching user configured tables: {tables}")
             for table in tables:
                 yield table

@@ -302,9 +309,13 @@ async def get_table_primary_key(self, table):
                 retry_count=self.retry_count,
             )
         ]
+
+        self._logger.debug(f"Found primary keys for '{table}' table")
+
         return primary_keys

     async def get_table_last_update_time(self, table):
+        self._logger.debug(f"Fetching last updated time for table: {table}")
         [last_update_time] = await anext(
             fetch(
                 cursor_func=partial(
@@ -332,15 +343,20 @@ async def data_streamer(self, table=None, query=None):
         Yields:
             list: It will first yield the column names, then data in each row
         """
+        if query is not None:
+            cursor_query = query
+            msg = f"Streaming records from database using query: {query}"
+        else:
+            cursor_query = self.queries.table_data(
+                schema=self.schema,
+                table=table,
+            )
+            msg = f"Streaming records from database for table: {table}"
+        self._logger.debug(msg)
         async for data in fetch(
             cursor_func=partial(
                 self.get_cursor,
-                self.queries.table_data(
-                    schema=self.schema,
-                    table=table,
-                )
-                if query is None
-                else query,
+                cursor_query,
             ),
             fetch_columns=True,
             fetch_size=self.fetch_size,
@@ -496,6 +512,7 @@ def row2doc(self, row, doc_id, table, timestamp):
         return row

     async def get_primary_key(self, tables):
+        self._logger.debug(f"Extracting primary keys for tables: {tables}")
         primary_key_columns = []
         for table in tables:
             primary_key_columns.extend(
@@ -534,6 +551,7 @@ async def fetch_documents_from_table(self, table):
         Yields:
             Dict: Document to be indexed
         """
+        self._logger.info(f"Fetching records for the table: {table}")
         try:
             docs_generator = self._yield_all_docs_from_tables(table=table)
             async for doc in docs_generator:
@@ -553,6 +571,9 @@ async def fetch_documents_from_query(self, tables, query):
         Yields:
             Dict: Document to be indexed
         """
+        self._logger.info(
+            f"Fetching records for {tables} tables using the custom query: {query}"
+        )
         try:
             docs_generator = self._yield_docs_custom_query(tables=tables, query=query)
             async for doc in docs_generator:
@@ -597,6 +618,7 @@ async def _yield_docs_custom_query(self, tables, query):
     async def _yield_all_docs_from_tables(self, table):
         row_count = await self.mssql_client.get_table_row_count(table=table)
         if row_count > 0:
+            self._logger.debug(f"Total '{row_count}' rows found in '{table}' table")
             # Query to get the table's primary key
             keys = await self.get_primary_key(tables=[table])
             if keys:
@@ -641,6 +663,9 @@ async def get_docs(self, filtering=None):
         """
         if filtering and filtering.has_advanced_rules():
             advanced_rules = filtering.get_advanced_rules()
+            self._logger.info(
+                f"Fetching records from the database using advanced sync rules: {advanced_rules}"
+            )
             for rule in advanced_rules:
                 query = rule.get("query")
                 tables = rule.get("tables")
@@ -661,9 +686,6 @@ async def get_docs(self, filtering=None):
                     )
                     continue

-                self._logger.debug(
-                    f"Found table: {table} in database: {self.database}."
-                )
                 table_count += 1
                 async for row in self.fetch_documents_from_table(table=table):
                     yield row, None
diff --git a/connectors/sources/oracle.py b/connectors/sources/oracle.py
index 0c6413516..991229697 100644
--- a/connectors/sources/oracle.py
+++ b/connectors/sources/oracle.py
@@ -130,6 +130,7 @@ async def get_cursor(self, query):
         Returns:
             cursor: Synchronous cursor
         """
+        self._logger.debug(f"Retrieving the cursor for query: {query}")
         try:
             loop = asyncio.get_running_loop()
             if self.connection is None:
@@ -159,6 +160,9 @@ async def ping(self):
     async def get_tables_to_fetch(self):
         tables = configured_tables(self.tables)
         if is_wildcard(tables):
+            self._logger.info(
+                "Fetching all tables as the configuration field 'tables' is set to '*'"
+            )
             async for row in fetch(
                 cursor_func=partial(
                     self.get_cursor,
@@ -171,6 +175,7 @@ async def get_tables_to_fetch(self):
             ):
                 yield row[0]
         else:
+            self._logger.info(f"Fetching user configured tables: {tables}")
             for table in tables:
                 yield table

@@ -190,6 +195,7 @@ async def get_table_row_count(self, table):
         return row_count

     async def get_table_primary_key(self, table):
+        self._logger.debug(f"Extracting primary keys for table: {table}")
         primary_keys = [
             key
             async for [key] in fetch(
@@ -204,9 +210,11 @@ async def get_table_primary_key(self, table):
                 retry_count=self.retry_count,
             )
         ]
+        self._logger.debug(f"Found primary keys for '{table}' table")
         return primary_keys

     async def get_table_last_update_time(self, table):
+        self._logger.debug(f"Fetching last updated time for table: {table}")
         [last_update_time] = await anext(
             fetch(
                 cursor_func=partial(
@@ -233,6 +241,7 @@ async def data_streamer(self, table):
         Yields:
             list: It will first yield the column names, then data in each row
         """
+        self._logger.debug(f"Streaming records from database for table: {table}")
         async for data in fetch(
             cursor_func=partial(
                 self.get_cursor,
@@ -388,10 +397,12 @@ async def fetch_documents(self, table):
         Yields:
             Dict: Document to be indexed
         """
+        self._logger.info(f"Fetching records for the table: {table}")
         try:
             row_count = await self.oracle_client.get_table_row_count(table=table)
             if row_count > 0:
                 # Query to get the table's primary key
+                self._logger.debug(f"Total '{row_count}' rows found in '{table}' table")
                 keys = await self.oracle_client.get_table_primary_key(table=table)
                 keys = map_column_names(column_names=keys, tables=[table])
                 if keys:
@@ -444,7 +455,6 @@ async def get_docs(self, filtering=None):
         """
         table_count = 0
         async for table in self.oracle_client.get_tables_to_fetch():
-            self._logger.debug(f"Found table: {table} in database: {self.database}.")
             table_count += 1
             async for row in self.fetch_documents(table=table):
                 yield row, None
diff --git a/connectors/sources/postgresql.py b/connectors/sources/postgresql.py
index 19914ea1a..522985a84 100644
--- a/connectors/sources/postgresql.py
+++ b/connectors/sources/postgresql.py
@@ -62,7 +62,7 @@ def table_primary_key(self, **kwargs):

     def table_data(self, **kwargs):
         """Query to get the table data"""
-        return f'SELECT * FROM {kwargs["schema"]}."{kwargs["table"]}" ORDER BY {kwargs["columns"]} LIMIT {kwargs["limit"]} OFFSET {kwargs["offset"]}'
+        return f'SELECT * FROM {kwargs["schema"]}."{kwargs["table"]}" LIMIT {kwargs["limit"]} OFFSET {kwargs["offset"]}'

     def table_last_update_time(self, **kwargs):
         """Query to get the last update time of the table"""
@@ -196,6 +196,7 @@ async def get_cursor(self, query):
         Returns:
             cursor: Asynchronous cursor
         """
+        self._logger.debug(f"Retrieving the cursor for query: {query}")
         try:
             async with self.engine.connect() as connection:  # pyright: ignore
                 cursor = await connection.execute(text(query))
@@ -218,6 +219,12 @@ async def ping(self):
     async def get_tables_to_fetch(self, is_filtering=False):
         tables = configured_tables(self.tables)
         if is_wildcard(tables) or is_filtering:
+            msg = (
+                "Fetching all tables as the configuration field 'tables' is set to '*'"
+                if not is_filtering
+                else "Fetching all tables as the advanced sync rules are enabled."
+            )
+            self._logger.info(msg)
             async for row in fetch(
                 cursor_func=partial(
                     self.get_cursor,
@@ -231,6 +238,7 @@ async def get_tables_to_fetch(self, is_filtering=False):
             ):
                 yield row[0]
         else:
+            self._logger.info(f"Fetching user configured tables: {tables}")
             for table in tables:
                 yield table

@@ -266,9 +274,12 @@ async def get_table_primary_key(self, table):
             )
         ]
+
+        self._logger.debug(f"Found primary keys for '{table}' table")
+
         return primary_keys

     async def get_table_last_update_time(self, table):
+        self._logger.debug(f"Fetching last updated time for table: {table}")
         [last_update_time] = await anext(
             fetch(
                 cursor_func=partial(
@@ -284,9 +295,7 @@ async def get_table_last_update_time(self, table):
             )
         )
         return last_update_time

-    async def data_streamer(
-        self, table=None, query=None, row_count=None, order_by_columns=None
-    ):
+    async def data_streamer(self, table=None, query=None, row_count=None):
         """Streaming data from a table

         Args:
@@ -299,23 +308,24 @@ async def data_streamer(
         Yields:
             list: It will first yield the column names, then data in each row
         """
-        if query is None and row_count is not None and order_by_columns is not None:
-            order_by_columns_list = ",".join(order_by_columns)
+        if query is None and row_count is not None:
+            self._logger.debug(f"Streaming records from database for table: {table} using pagination")
             offset = 0
             fetch_columns = True
             while True:
                 async for data in fetch(
                     cursor_func=partial(
                         self.get_cursor,
-                        self.queries.table_data(
-                            schema=self.schema,
-                            table=table,
-                            columns=order_by_columns_list,
-                            limit=FETCH_LIMIT,
-                            offset=offset,
-                        )
-                        if query is None
-                        else query,
+                        (
+                            self.queries.table_data(
+                                schema=self.schema,
+                                table=table,
+                                limit=FETCH_LIMIT,
+                                offset=offset,
+                            )
+                            if query is None
+                            else query
+                        ),
                     ),
                     fetch_columns=fetch_columns,
                     fetch_size=self.fetch_size,
@@ -328,15 +338,18 @@ async def data_streamer(
                 if row_count <= offset:
                     return
         else:
+            self._logger.debug(f"Streaming records from database for table: {table}")
             async for data in fetch(
                 cursor_func=partial(
                     self.get_cursor,
-                    self.queries.table_data(
-                        schema=self.schema,
-                        table=table,
-                    )
-                    if query is None
-                    else query,
+                    (
+                        self.queries.table_data(
+                            schema=self.schema,
+                            table=table,
+                        )
+                        if query is None
+                        else query
+                    ),
                 ),
                 fetch_columns=True,
                 fetch_size=self.fetch_size,
@@ -496,17 +509,15 @@ def row2doc(self, row, doc_id, table, timestamp):
         return row

     async def get_primary_key(self, tables):
+        self._logger.debug(f"Extracting primary keys for tables: {tables}")
         primary_key_columns = []
         for table in tables:
             primary_key_columns.extend(
                 await self.postgresql_client.get_table_primary_key(table)
             )
         primary_key_columns = sorted(primary_key_columns)
-        return (
-            map_column_names(
-                column_names=primary_key_columns, schema=self.schema, tables=tables
-            ),
-            primary_key_columns,
+        return map_column_names(
+            column_names=primary_key_columns, schema=self.schema, tables=tables
         )

     async def fetch_documents_from_table(self, table):
@@ -518,6 +529,7 @@ async def fetch_documents_from_table(self, table):
         Yields:
             Dict: Document to be indexed
         """
+        self._logger.info(f"Fetching records for the table: {table}")
         try:
             docs_generator = self._yield_all_docs_from_tables(table=table)
             async for doc in docs_generator:
@@ -537,6 +549,9 @@ async def fetch_documents_from_query(self, tables, query):
         Yields:
             Dict: Document to be indexed
         """
+        self._logger.info(
+            f"Fetching records for {tables} tables using the custom query: {query}"
+        )
         try:
             docs_generator = self._yield_docs_custom_query(tables=tables, query=query)
             async for doc in docs_generator:
@@ -547,7 +562,7 @@ async def _yield_docs_custom_query(self, tables, query):
-        primary_key_columns, _ = await self.get_primary_key(tables=tables)
+        primary_key_columns = await self.get_primary_key(tables=tables)
         if not primary_key_columns:
             self._logger.warning(
                 f"Skipping tables {', '.join(tables)} from database {self.database} since no primary key is associated with them. Assign primary key to the tables to index it in the next sync interval."
@@ -582,7 +597,8 @@ async def _yield_all_docs_from_tables(self, table):
         row_count = await self.postgresql_client.get_table_row_count(table=table)
         if row_count > 0:
             # Query to get the table's primary key
-            keys, order_by_columns = await self.get_primary_key(tables=[table])
+            self._logger.debug(f"Total '{row_count}' rows found in '{table}' table")
+            keys = await self.get_primary_key(tables=[table])
             if keys:
                 try:
                     last_update_time = (
@@ -596,10 +612,7 @@ async def _yield_all_docs_from_tables(self, table):
                     )
                     last_update_time = None
                 async for row in self.yield_rows_for_query(
-                    primary_key_columns=keys,
-                    tables=[table],
-                    row_count=row_count,
-                    order_by_columns=order_by_columns,
+                    primary_key_columns=keys, tables=[table], row_count=row_count
                 ):
                     doc_id = (
                         f"{self.database}_{self.schema}_{hash_id([table], row, keys)}"
@@ -620,16 +633,11 @@ async def _yield_all_docs_from_tables(self, table):
             self._logger.warning(f"No rows found for {table}.")

     async def yield_rows_for_query(
-        self,
-        primary_key_columns,
-        tables,
-        query=None,
-        row_count=None,
-        order_by_columns=None,
+        self, primary_key_columns, tables, query=None, row_count=None
     ):
         if query is None:
             streamer = self.postgresql_client.data_streamer(
-                table=tables[0], row_count=row_count, order_by_columns=order_by_columns
+                table=tables[0], row_count=row_count
             )
         else:
             streamer = self.postgresql_client.data_streamer(query=query)
@@ -655,6 +663,9 @@ async def get_docs(self, filtering=None):
         """
         if filtering and filtering.has_advanced_rules():
             advanced_rules = filtering.get_advanced_rules()
+            self._logger.info(
+                f"Fetching records from the database using advanced sync rules: {advanced_rules}"
+            )
             for rule in advanced_rules:
                 query = rule.get("query")
                 tables = rule.get("tables")
@@ -668,9 +679,6 @@ async def get_docs(self, filtering=None):

             table_count = 0
             async for table in self.postgresql_client.get_tables_to_fetch():
-                self._logger.debug(
-                    f"Found table: {table} in database: {self.database}."
-                )
                 table_count += 1
                 async for row in self.fetch_documents_from_table(
                     table=table,
diff --git a/tests/sources/test_confluence.py b/tests/sources/test_confluence.py
index f7e2e46b2..3454654c0 100644
--- a/tests/sources/test_confluence.py
+++ b/tests/sources/test_confluence.py
@@ -24,6 +24,7 @@
     CONFLUENCE_SERVER,
     ConfluenceClient,
     ConfluenceDataSource,
+    InvalidConfluenceDataSourceTypeError,
 )
 from connectors.utils import ssl_context
 from tests.commons import AsyncIterator
@@ -351,10 +352,12 @@


 @asynccontextmanager
-async def create_confluence_source(use_text_extraction_service=False):
+async def create_confluence_source(
+    use_text_extraction_service=False, data_source=CONFLUENCE_SERVER
+):
     async with create_source(
         ConfluenceDataSource,
-        data_source=CONFLUENCE_SERVER,
+        data_source=data_source,
         username="admin",
         password="changeme",
         confluence_url=HOST_URL,
@@ -924,14 +927,34 @@ async def test_get_docs(spaces_patch, pages_patch, attachment_patch, content_pat


 @pytest.mark.asyncio
-async def test_get_session():
-    """Test that the instance of session returned is always the same for the datasource class."""
+@pytest.mark.parametrize(
+    "data_source_type", [CONFLUENCE_CLOUD, CONFLUENCE_DATA_CENTER, CONFLUENCE_SERVER]
+)
+async def test_get_session(data_source_type):
+    async with create_confluence_source(data_source=data_source_type) as source:
+        try:
+            source.confluence_client._get_session()
+        except Exception as e:
+            pytest.fail(
+                f"Should not raise for valid data source type '{data_source_type}'. Exception: {e}"
+            )
+
+
+@pytest.mark.asyncio
+async def test_get_session_multiple_calls_return_same_instance():
     async with create_confluence_source() as source:
         first_instance = source.confluence_client._get_session()
         second_instance = source.confluence_client._get_session()
         assert first_instance is second_instance


+@pytest.mark.asyncio
+async def test_get_session_raise_on_invalid_data_source_type():
+    async with create_confluence_source(data_source="invalid") as source:
+        with pytest.raises(InvalidConfluenceDataSourceTypeError):
+            source.confluence_client._get_session()
+
+
 @pytest.mark.asyncio
 async def test_get_access_control_dls_disabled():
     async with create_confluence_source() as source: