Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MultiDomainBasicAuth #12934

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
13 changes: 7 additions & 6 deletions src/pip/_internal/cli/index_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _build_session(
retries: Optional[int] = None,
timeout: Optional[int] = None,
) -> "PipSession":
from pip._internal.network.session import PipSession
from pip._internal.network.session import MultiDomainAuthSettings, PipSession

cache_dir = options.cache_dir
assert not cache_dir or os.path.isabs(cache_dir)
Expand All @@ -100,8 +100,13 @@ def _build_session(
cache=os.path.join(cache_dir, "http-v2") if cache_dir else None,
retries=retries if retries is not None else options.retries,
trusted_hosts=options.trusted_hosts,
index_urls=self._get_index_urls(options),
ssl_context=ssl_context,
multi_domain_auth_settings=MultiDomainAuthSettings(
index_urls=self._get_index_urls(options),
# Determine if we can prompt the user for authentication or not
prompting=not options.no_input,
keyring_provider=options.keyring_provider,
),
)

# Handle custom ca-bundles from the user
Expand All @@ -124,10 +129,6 @@ def _build_session(
}
session.trust_env = False

# Determine if we can prompt the user for authentication or not
session.auth.prompting = not options.no_input
session.auth.keyring_provider = options.keyring_provider

return session


Expand Down
142 changes: 69 additions & 73 deletions src/pip/_internal/network/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import shutil
import subprocess
import sysconfig
import typing
import urllib.parse
from abc import ABC, abstractmethod
from functools import lru_cache
from os.path import commonprefix
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
Expand All @@ -33,8 +31,6 @@

logger = getLogger(__name__)

KEYRING_DISABLED = False


class Credentials(NamedTuple):
url: str
Expand Down Expand Up @@ -159,13 +155,52 @@ def _set_password(self, service_name: str, username: str, password: str) -> None
return None


@lru_cache(maxsize=None)
def which_skip_scripts(command: str) -> Optional[str]:
"""
Find the given command, but skip past the "scripts" directory. This is useful
if you want to find a system install of some command, and not have it shadowed
by a virtualenv that happens to provide that very command.
"""
path = shutil.which(command)
if path and path.startswith(sysconfig.get_path("scripts")):
# all code within this function is stolen from shutil.which implementation
def PATH_as_shutil_which_determines_it() -> str:
path = os.environ.get("PATH", None)
if path is None:
try:
path = os.confstr("CS_PATH")
except (AttributeError, ValueError):
# os.confstr() or CS_PATH is not available
path = None

if path is None:
path = os.defpath
# bpo-35755: Don't use os.defpath if the PATH environment variable is
# set to an empty string

return path

scripts = Path(sysconfig.get_path("scripts"))

paths = []
for path in PATH_as_shutil_which_determines_it().split(os.pathsep):
p = Path(path)
try:
if not p.samefile(scripts):
paths.append(path)
except FileNotFoundError:
pass

path = os.pathsep.join(paths)

path = shutil.which(command, path=path)

return path


def get_keyring_provider(provider: str) -> KeyRingBaseProvider:
logger.verbose("Keyring provider requested: %s", provider)

# keyring has previously failed and been disabled
if KEYRING_DISABLED:
provider = "disabled"
if provider in ["import", "auto"]:
try:
impl = KeyRingPythonProvider()
Expand All @@ -181,38 +216,7 @@ def get_keyring_provider(provider: str) -> KeyRingBaseProvider:
msg = msg + ", trying to find a keyring executable as a fallback"
logger.warning(msg, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
if provider in ["subprocess", "auto"]:
cli = shutil.which("keyring")
if cli and cli.startswith(sysconfig.get_path("scripts")):
# all code within this function is stolen from shutil.which implementation
@typing.no_type_check
def PATH_as_shutil_which_determines_it() -> str:
path = os.environ.get("PATH", None)
if path is None:
try:
path = os.confstr("CS_PATH")
except (AttributeError, ValueError):
# os.confstr() or CS_PATH is not available
path = os.defpath
# bpo-35755: Don't use os.defpath if the PATH environment variable is
# set to an empty string

return path

scripts = Path(sysconfig.get_path("scripts"))

paths = []
for path in PATH_as_shutil_which_determines_it().split(os.pathsep):
p = Path(path)
try:
if not p.samefile(scripts):
paths.append(path)
except FileNotFoundError:
pass

path = os.pathsep.join(paths)

cli = shutil.which("keyring", path=path)

cli = which_skip_scripts("keyring")
if cli:
logger.verbose("Keyring provider set: subprocess with executable %s", cli)
return KeyRingCliProvider(cli)
Expand All @@ -228,35 +232,27 @@ def __init__(
index_urls: Optional[List[str]] = None,
keyring_provider: str = "auto",
) -> None:
self.prompting = prompting
self.index_urls = index_urls
self.keyring_provider = keyring_provider # type: ignore[assignment]
self.passwords: Dict[str, AuthInfo] = {}
self._prompting = prompting
self._index_urls = index_urls
self._keyring_provider_name = keyring_provider
self._keyring_provider = get_keyring_provider(self._keyring_provider_name)
self._passwords: Dict[str, AuthInfo] = {}
# When the user is prompted to enter credentials and keyring is
# available, we will offer to save them. If the user accepts,
# this value is set to the credentials they entered. After the
# request authenticates, the caller should call
# ``save_credentials`` to save these.
self._credentials_to_save: Optional[Credentials] = None

@property
def keyring_provider(self) -> KeyRingBaseProvider:
return get_keyring_provider(self._keyring_provider)

@keyring_provider.setter
def keyring_provider(self, provider: str) -> None:
# The free function get_keyring_provider has been decorated with
# functools.cache. If an exception occurs in get_keyring_auth that
# cache will be cleared and keyring disabled, take that into account
# if you want to remove this indirection.
self._keyring_provider = provider

@property
def use_keyring(self) -> bool:
# We won't use keyring when --no-input is passed unless
# a specific provider is requested because it might require
# user interaction
return self.prompting or self._keyring_provider not in ["auto", "disabled"]
return self._prompting or self._keyring_provider_name not in [
"auto",
"disabled",
]

def _get_keyring_auth(
self,
Expand All @@ -269,7 +265,7 @@ def _get_keyring_auth(
return None

try:
return self.keyring_provider.get_auth_info(url, username)
return self._keyring_provider.get_auth_info(url, username)
except Exception as exc:
# Log the full exception (with stacktrace) at debug, so it'll only
# show up when running in verbose mode.
Expand All @@ -279,9 +275,9 @@ def _get_keyring_auth(
"Keyring is skipped due to an exception: %s",
str(exc),
)
global KEYRING_DISABLED
KEYRING_DISABLED = True
get_keyring_provider.cache_clear()
# Disable keyring.
self._keyring_provider_name = "disabled"
self._keyring_provider = get_keyring_provider(self._keyring_provider_name)
return None

def _get_index_url(self, url: str) -> Optional[str]:
Expand All @@ -297,15 +293,15 @@ def _get_index_url(self, url: str) -> Optional[str]:
Returns None if no matching index was found, or if --no-index
was specified by the user.
"""
if not url or not self.index_urls:
if not url or not self._index_urls:
return None

url = remove_auth_from_url(url).rstrip("/") + "/"
parsed_url = urllib.parse.urlsplit(url)

candidates = []

for index in self.index_urls:
for index in self._index_urls:
index = index.rstrip("/") + "/"
parsed_index = urllib.parse.urlsplit(remove_auth_from_url(index))
if parsed_url == parsed_index:
Expand Down Expand Up @@ -410,8 +406,8 @@ def _get_url_and_credentials(
# Do this if either the username or the password is missing.
# This accounts for the situation in which the user has specified
# the username in the index url, but the password comes from keyring.
if (username is None or password is None) and netloc in self.passwords:
un, pw = self.passwords[netloc]
if (username is None or password is None) and netloc in self._passwords:
un, pw = self._passwords[netloc]
# It is possible that the cached credentials are for a different username,
# in which case the cache should be ignored.
if username is None or username == un:
Expand All @@ -426,7 +422,7 @@ def _get_url_and_credentials(
password = password or ""

# Store any acquired credentials.
self.passwords[netloc] = (username, password)
self._passwords[netloc] = (username, password)

assert (
# Credentials were found
Expand Down Expand Up @@ -457,7 +453,7 @@ def __call__(self, req: Request) -> Request:
def _prompt_for_password(
self, netloc: str
) -> Tuple[Optional[str], Optional[str], bool]:
username = ask_input(f"User for {netloc}: ") if self.prompting else None
username = ask_input(f"User for {netloc}: ") if self._prompting else None
if not username:
return None, None, False
if self.use_keyring:
Expand All @@ -470,9 +466,9 @@ def _prompt_for_password(
# Factored out to allow for easy patching in tests
def _should_save_password_to_keyring(self) -> bool:
if (
not self.prompting
not self._prompting
or not self.use_keyring
or not self.keyring_provider.has_keyring
or not self._keyring_provider.has_keyring
):
return False
return ask("Save credentials to keyring [y/N]: ", ["y", "n"]) == "y"
Expand All @@ -494,7 +490,7 @@ def handle_401(self, resp: Response, **kwargs: Any) -> Response:
)

# We are not able to prompt the user so simply return the response
if not self.prompting and not username and not password:
if not self._prompting and not username and not password:
return resp

parsed = urllib.parse.urlparse(resp.url)
Expand All @@ -507,7 +503,7 @@ def handle_401(self, resp: Response, **kwargs: Any) -> Response:
# Store the new username and password to use for future requests
self._credentials_to_save = None
if username is not None and password is not None:
self.passwords[parsed.netloc] = (username, password)
self._passwords[parsed.netloc] = (username, password)

# Prompt to save the password to keyring
if save and self._should_save_password_to_keyring():
Expand Down Expand Up @@ -551,15 +547,15 @@ def warn_on_401(self, resp: Response, **kwargs: Any) -> None:
def save_credentials(self, resp: Response, **kwargs: Any) -> None:
"""Response callback to save credentials on success."""
assert (
self.keyring_provider.has_keyring
self._keyring_provider.has_keyring
), "should never reach here without keyring"

creds = self._credentials_to_save
self._credentials_to_save = None
if creds and resp.status_code < 400:
try:
logger.info("Saving credentials to keyring")
self.keyring_provider.save_auth_info(
self._keyring_provider.save_auth_info(
creds.url, creds.username, creds.password
)
except Exception:
Expand Down
18 changes: 16 additions & 2 deletions src/pip/_internal/network/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sys
import urllib.parse
import warnings
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -317,6 +318,13 @@ def cert_verify(
super().cert_verify(conn=conn, url=url, verify=False, cert=cert)


@dataclass(frozen=True)
class MultiDomainAuthSettings:
prompting: bool = True
index_urls: Optional[List[str]] = None
keyring_provider: str = "auto"


class PipSession(requests.Session):
timeout: Optional[int] = None

Expand All @@ -326,8 +334,8 @@ def __init__(
retries: int = 0,
cache: Optional[str] = None,
trusted_hosts: Sequence[str] = (),
index_urls: Optional[List[str]] = None,
ssl_context: Optional["SSLContext"] = None,
multi_domain_auth_settings: Optional[MultiDomainAuthSettings] = None,
**kwargs: Any,
) -> None:
"""
Expand All @@ -344,7 +352,13 @@ def __init__(
self.headers["User-Agent"] = user_agent()

# Attach our Authentication handler to the session
self.auth = MultiDomainBasicAuth(index_urls=index_urls)
if multi_domain_auth_settings is None:
multi_domain_auth_settings = MultiDomainAuthSettings()
self.auth = MultiDomainBasicAuth(
prompting=multi_domain_auth_settings.prompting,
index_urls=multi_domain_auth_settings.index_urls,
keyring_provider=multi_domain_auth_settings.keyring_provider,
)

# Create our urllib3.Retry instance which will allow us to customize
# how we handle retries.
Expand Down
14 changes: 12 additions & 2 deletions tests/unit/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,12 +1012,17 @@ def test_collect_page_sources(
# Check that index URLs are marked as *un*cacheable.
assert not pages[0].link.cache_link_parsing

# Skip past a couple of log messages about keyring.
record_tuples = caplog.record_tuples
assert record_tuples.pop(0)[2].startswith("Keyring provider")
assert record_tuples.pop(0)[2].startswith("Keyring provider")

expected_message = dedent(
"""\
1 location(s) to search for versions of twine:
* https://pypi.org/simple/twine/"""
)
assert caplog.record_tuples == [
assert record_tuples == [
("pip._internal.index.collector", logging.DEBUG, expected_message),
]

Expand Down Expand Up @@ -1058,12 +1063,17 @@ def test_collect_file_sources(
assert len(files) > 0
check_links_include(files, names=["singlemodule-0.0.1.tar.gz"])

# Skip past a couple of log messages about keyring.
record_tuples = caplog.record_tuples
assert record_tuples.pop(0)[2].startswith("Keyring provider")
assert record_tuples.pop(0)[2].startswith("Keyring provider")

expected_message = dedent(
"""\
1 location(s) to search for versions of singlemodule:
* https://pypi.org/simple/singlemodule/"""
)
assert caplog.record_tuples == [
assert record_tuples == [
("pip._internal.index.collector", logging.DEBUG, expected_message),
]

Expand Down
Loading