Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Outlook] Feature: Office365 multi-user support #2844

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
119 changes: 108 additions & 11 deletions connectors/sources/outlook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

import asyncio
import os
from abc import ABC, abstractmethod
from copy import copy
from datetime import date
from functools import cached_property, partial
from typing import List

import aiofiles
import aiohttp
Expand Down Expand Up @@ -348,13 +350,13 @@ async def get_user_accounts(self):
yield user_account


class Office365Users:
"""Fetch users from Office365 Active Directory"""
class BaseOffice365User(ABC):
"""Abstract base class for Office 365 user management"""

def __init__(self, client_id, client_secret, tenant_id):
self.tenant_id = tenant_id
self.client_id = client_id
self.client_secret = client_secret
self.tenant_id = tenant_id

@cached_property
def _get_session(self):
Expand Down Expand Up @@ -403,6 +405,21 @@ async def _fetch_token(self):
except Exception as exception:
self._check_errors(response=exception)

@abstractmethod
async def get_users(self):
pass

@abstractmethod
async def get_user_accounts(self):
pass


class Office365Users(BaseOffice365User):
"""Fetch users from Office365 Active Directory"""

def __init__(self, client_id, client_secret, tenant_id):
super().__init__(client_id, client_secret, tenant_id)

@retryable(
retries=RETRIES,
interval=RETRY_INTERVAL,
Expand Down Expand Up @@ -456,6 +473,57 @@ async def get_user_accounts(self):
yield user_account


class MultiOffice365Users(BaseOffice365User):
"""Fetch multiple Office365 users based on a list of email addresses."""

def __init__(self, client_id, client_secret, tenant_id, client_emails: List[str]):
super().__init__(client_id, client_secret, tenant_id)
self.client_emails = client_emails

async def get_users(self):
access_token = await self._fetch_token()
for email in self.client_emails:
url = f"https://graph.microsoft.com/v1.0/users/{email}"
ilyasabdellaoui marked this conversation as resolved.
Show resolved Hide resolved
try:
async with self._get_session.get(
url=url,
headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
) as response:
json_response = await response.json()
yield json_response
except Exception:
raise

async def get_user_accounts(self):
async for user in self.get_users():
mail = user.get("mail")
if mail is None:
continue

credentials = OAuth2Credentials(
client_id=self.client_id,
tenant_id=self.tenant_id,
client_secret=self.client_secret,
identity=Identity(primary_smtp_address=mail),
)
configuration = Configuration(
credentials=credentials,
auth_type=OAUTH2,
service_endpoint=EWS_ENDPOINT,
retry_policy=FaultTolerance(max_wait=120),
)
user_account = Account(
primary_smtp_address=mail,
config=configuration,
autodiscover=False,
access_type=IMPERSONATION,
)
yield user_account


class OutlookDocFormatter:
"""Format Outlook object documents to Elasticsearch document"""

Expand Down Expand Up @@ -583,6 +651,27 @@ def attachment_doc_formatter(self, attachment, attachment_type, timezone):
}


class UserFactory:
"""Factory class for creating Office365 user instances"""

@staticmethod
def create_user(configuration: dict) -> BaseOffice365User:
if configuration.get("client_emails"):
client_emails = [email.strip() for email in configuration["client_emails"].split(",")]
ilyasabdellaoui marked this conversation as resolved.
Show resolved Hide resolved
return MultiOffice365Users(
client_id=configuration["client_id"],
client_secret=configuration["client_secret"],
tenant_id=configuration["tenant_id"],
client_emails=client_emails
)
else:
return Office365Users(
client_id=configuration["client_id"],
client_secret=configuration["client_secret"],
tenant_id=configuration["tenant_id"]
)


class OutlookClient:
"""Outlook client to handle API calls made to Outlook"""

Expand All @@ -605,11 +694,7 @@ def set_logger(self, logger_):
@cached_property
def _get_user_instance(self):
if self.is_cloud:
return Office365Users(
client_id=self.configuration["client_id"],
client_secret=self.configuration["client_secret"],
tenant_id=self.configuration["tenant_id"],
)
return UserFactory.create_user(self.configuration)

return ExchangeUsers(
ad_server=self.configuration["active_directory_server"],
Expand Down Expand Up @@ -666,9 +751,12 @@ async def get_tasks(self, account):
yield task

async def get_contacts(self, account):
folder = account.root / "Top of Information Store" / "Contacts"
for contact in await asyncio.to_thread(folder.all().only, *CONTACT_FIELDS):
yield contact
try:
folder = account.root / "Top of Information Store" / "Contacts"
for contact in await asyncio.to_thread(folder.all().only, *CONTACT_FIELDS):
yield contact
except Exception:
raise


class OutlookDataSource(BaseDataSource):
Expand Down Expand Up @@ -735,6 +823,13 @@ def get_default_configuration(cls):
"sensitive": True,
"type": "str",
},
"client_emails": {
"depends_on": [{"field": "data_source", "value": OUTLOOK_CLOUD}],
"label": "Client Email Addresses (comma-separated)",
ilyasabdellaoui marked this conversation as resolved.
Show resolved Hide resolved
"order": 5,
ilyasabdellaoui marked this conversation as resolved.
Show resolved Hide resolved
"required": False,
"type": "str",
ilyasabdellaoui marked this conversation as resolved.
Show resolved Hide resolved
},
"exchange_server": {
"depends_on": [{"field": "data_source", "value": OUTLOOK_SERVER}],
"label": "Exchange Server",
Expand Down Expand Up @@ -1072,9 +1167,11 @@ async def get_docs(self, filtering=None):
dictionary: dictionary containing meta-data of the files.
"""
async for account in self.client._get_user_instance.get_user_accounts():
self._logger.debug(f"Processing account: {account}")
timezone = account.default_timezone or DEFAULT_TIMEZONE

async for mail in self._fetch_mails(account=account, timezone=timezone):
self._logger.debug(f"Fetched mail: {mail}")
yield mail

async for contact in self._fetch_contacts(
Expand Down
99 changes: 81 additions & 18 deletions tests/sources/test_outlook.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ async def create_outlook_source(
tenant_id="foo",
client_id="bar",
client_secret="faa",
client_emails=None,
exchange_server="127.0.0.1",
active_directory_server="127.0.0.1",
username="fee",
Expand All @@ -383,12 +384,16 @@ async def create_outlook_source(
ssl_ca="",
use_text_extraction_service=False,
):
if client_emails is None:
client_emails = ""

async with create_source(
OutlookDataSource,
data_source=data_source,
tenant_id=tenant_id,
client_id=client_id,
client_secret=client_secret,
client_emails=client_emails,
exchange_server=exchange_server,
active_directory_server=active_directory_server,
username=username,
Expand All @@ -415,26 +420,36 @@ def get_stream_reader():
return async_mock


def side_effect_function(url, headers):
def side_effect_function(client_emails=None):
"""Dynamically changing return values for API calls
Args:
url, ssl: Params required for get call
client_emails: Optional string of comma-separated email addresses
"""
if url == "https://graph.microsoft.com/v1.0/users?$top=999":
return get_json_mock(
mock_response={
"@odata.nextLink": "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token",
"value": [{"mail": "[email protected]"}],
},
status=200,
)
elif (
url
== "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token"
):
return get_json_mock(
mock_response={"value": [{"mail": "[email protected]"}]}, status=200
)
def inner(url, headers):
if client_emails:
emails = [email.strip() for email in client_emails.split(",")]
for email in emails:
if url == f"https://graph.microsoft.com/v1.0/users/{email}":
users_response = {"value": [{"mail": email}]}
return get_json_mock(mock_response=users_response, status=200)
elif url == "https://graph.microsoft.com/v1.0/users?$top=999":
return get_json_mock(
mock_response={
"@odata.nextLink": "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token",
"value": [{"mail": "[email protected]"}],
},
status=200,
)
elif (
url
== "https://graph.microsoft.com/v1.0/users?$top=999&$skipToken=fake-skip-token"
):
return get_json_mock(
mock_response={"value": [{"mail": "[email protected]"}]}, status=200
)

return inner


@pytest.mark.asyncio
Expand All @@ -459,6 +474,7 @@ def side_effect_function(url, headers):
"tenant_id": "foo",
"client_id": "bar",
"client_secret": "",
"client_emails": None,
}
),
],
Expand Down Expand Up @@ -497,6 +513,17 @@ async def test_validate_configuration_with_invalid_dependency_fields_raises_erro
"tenant_id": "foo",
"client_id": "bar",
"client_secret": "foo.bar",
"client_emails": None
}
),
(
# Outlook Cloud with non-blank dependent fields & client_emails provided
{
"data_source": OUTLOOK_CLOUD,
"tenant_id": "foo",
"client_id": "bar",
"client_secret": "foo.bar",
"client_emails": "[email protected]"
}
),
],
Expand Down Expand Up @@ -552,7 +579,7 @@ async def test_ping_for_cloud():
):
with mock.patch(
"aiohttp.ClientSession.get",
side_effect=side_effect_function,
side_effect=side_effect_function(),
):
await source.ping()

Expand Down Expand Up @@ -597,13 +624,49 @@ async def test_get_users_for_cloud():
):
with mock.patch(
"aiohttp.ClientSession.get",
side_effect=side_effect_function,
side_effect=side_effect_function(),
):
async for response in source.client._get_user_instance.get_users():
user_mails = [user["mail"] for user in response["value"]]
users.extend(user_mails)
assert users == ["[email protected]", "[email protected]"]

client_emails = "[email protected]"
async with create_outlook_source(client_emails=client_emails) as source:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest making these their own test functions. That way, any failures are more easily identifyable, and they can have more focused test names to help the reader understand what's being tested.

users = []
with mock.patch(
"aiohttp.ClientSession.post",
return_value=get_json_mock(
mock_response={"access_token": "fake-token"}, status=200
),
):
with mock.patch(
"aiohttp.ClientSession.get",
side_effect=side_effect_function(client_emails),
):
async for response in source.client._get_user_instance.get_users():
user_mails = [user["mail"] for user in response["value"]]
users.extend(user_mails)
assert users == ["[email protected]"]

client_emails = "[email protected], [email protected]"
async with create_outlook_source(client_emails=client_emails) as source:
users = []
with mock.patch(
"aiohttp.ClientSession.post",
return_value=get_json_mock(
mock_response={"access_token": "fake-token"}, status=200
),
):
with mock.patch(
"aiohttp.ClientSession.get",
side_effect=side_effect_function(client_emails),
):
async for response in source.client._get_user_instance.get_users():
user_mails = [user["mail"] for user in response["value"]]
users.extend(user_mails)
assert set(users) == {"[email protected]", "[email protected]"}


@pytest.mark.asyncio
@patch("connectors.sources.outlook.Connection")
Expand Down