airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth

  1#
  2# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
  3#
  4
  5import logging
  6from abc import abstractmethod
  7from datetime import timedelta
  8from json import JSONDecodeError
  9from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union
 10
 11import backoff
 12import requests
 13from requests.auth import AuthBase
 14
 15from airbyte_cdk.models import FailureType, Level
 16from airbyte_cdk.sources.http_logger import format_http_message
 17from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
 18from airbyte_cdk.utils import AirbyteTracedException
 19from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
 20from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse
 21
 22from ..exceptions import DefaultBackoffException
 23
 24logger = logging.getLogger("airbyte")
 25_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()
 26
 27
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the recursive search for a target key in the token refresh
    response goes deeper than the allowed maximum depth
    (see `AbstractOauth2Authenticator._find_and_get_value_from_response`).
    """
 33
 34
class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract base class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator
    is designed to generically perform the refresh flow without regard to how config fields are read or stored,
    delegating that behavior to the classes implementing the interface.
    """

    # Handed to `format_http_message` when the refresh response is logged.
    # NOTE(review): presumably fills the stream-name argument, since no stream applies here - confirm.
    _NO_STREAM_NAME = None
 43
 44    def __init__(
 45        self,
 46        refresh_token_error_status_codes: Tuple[int, ...] = (),
 47        refresh_token_error_key: str = "",
 48        refresh_token_error_values: Tuple[str, ...] = (),
 49    ) -> None:
 50        """
 51        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 52        then http errors with such params will be wrapped in AirbyteTracedException.
 53        """
 54        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 55        self._refresh_token_error_key = refresh_token_error_key
 56        self._refresh_token_error_values = refresh_token_error_values
 57
 58    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 59        """Attach the HTTP headers required to authenticate on the HTTP request"""
 60        request.headers.update(self.get_auth_header())
 61        return request
 62
 63    @property
 64    def _is_access_token_flow(self) -> bool:
 65        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 66
    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates whether the token expiry value is the date until which the token is valid (True)
        rather than the amount of time it remains valid (False).

        The base implementation returns False; `_parse_token_expiration_date` reads this flag to decide
        how the expiry value is interpreted.
        """

        return False
 74
    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the expiry datetime; only relevant when the expiry value is returned as the expiration
        datetime instead of the number of seconds until the token expires.

        The base implementation returns None.
        """

        return None
 82
 83    def get_auth_header(self) -> Mapping[str, Any]:
 84        """HTTP header to set on the requests"""
 85        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 86        return {"Authorization": f"Bearer {token}"}
 87
 88    def get_access_token(self) -> str:
 89        """Returns the access token"""
 90        if self.token_has_expired():
 91            token, expires_in = self.refresh_access_token()
 92            self.access_token = token
 93            self.set_token_expiry_date(expires_in)
 94
 95        return self.access_token
 96
 97    def token_has_expired(self) -> bool:
 98        """Returns True if the token is expired"""
 99        return ab_datetime_now() > self.get_token_expiry_date()
100
101    def build_refresh_request_body(self) -> Mapping[str, Any]:
102        """
103        Returns the request body to set on the refresh request
104
105        Override to define additional parameters
106        """
107        payload: MutableMapping[str, Any] = {
108            self.get_grant_type_name(): self.get_grant_type(),
109            self.get_client_id_name(): self.get_client_id(),
110            self.get_client_secret_name(): self.get_client_secret(),
111            self.get_refresh_token_name(): self.get_refresh_token(),
112        }
113
114        if self.get_scopes():
115            payload["scopes"] = self.get_scopes()
116
117        if self.get_refresh_request_body():
118            for key, val in self.get_refresh_request_body().items():
119                # We defer to existing oauth constructs over custom configured fields
120                if key not in payload:
121                    payload[key] = val
122
123        return payload
124
125    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
126        """
127        Returns the request headers to set on the refresh request
128
129        """
130        headers = self.get_refresh_request_headers()
131        return headers if headers else None
132
133    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
134        """
135        Returns the refresh token and its expiration datetime
136
137        :return: a tuple of (access_token, token_lifespan)
138        """
139        response_json = self._make_handled_request()
140        self._ensure_access_token_in_response(response_json)
141
142        return (
143            self._extract_access_token(response_json),
144            self._extract_token_expiry_date(response_json),
145        )
146
147    # ----------------
148    # PRIVATE METHODS
149    # ----------------
150
151    def _default_token_expiry_date(self) -> AirbyteDateTime:
152        """
153        Returns the default token expiry date
154        """
155        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
156        default_token_expiry_duration_hours = 1  # 1 hour
157        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
158
159    def _wrap_refresh_token_exception(
160        self, exception: requests.exceptions.RequestException
161    ) -> bool:
162        """
163        Wraps and handles exceptions that occur during the refresh token process.
164
165        This method checks if the provided exception is related to a refresh token error
166        by examining the response status code and specific error content.
167
168        Args:
169            exception (requests.exceptions.RequestException): The exception raised during the request.
170
171        Returns:
172            bool: True if the exception is related to a refresh token error, False otherwise.
173        """
174        try:
175            if exception.response is not None:
176                exception_content = exception.response.json()
177            else:
178                return False
179        except JSONDecodeError:
180            return False
181        return (
182            exception.response.status_code in self._refresh_token_error_status_codes
183            and exception_content.get(self._refresh_token_error_key)
184            in self._refresh_token_error_values
185        )
186
    # Retries with exponential backoff whenever the call raises DefaultBackoffException,
    # giving up once 300 seconds have been spent in total.
    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the body from
        `build_refresh_request_body()` and the headers from `build_refresh_request_headers()`.
        The response is logged before its status is checked, so failed responses are logged too.

        NOTE(review): no `timeout` is passed to `requests.request`, so the call presumably waits
        for as long as the server keeps the connection open - confirm whether that is intended.

        Returns:
            Any: The decoded JSON body of the token refresh response.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                                     or any 5xx server error; this is what triggers the retry.
            AirbyteTracedException: If the error matches the configured refresh token error
                                    (see `_wrap_refresh_token_exception`), prompting
                                    re-authentication.
            requests.exceptions.RequestException: Any other request error, re-raised unchanged.
            Exception: Any non-request error, wrapped with an "Error while refreshing access token"
                       message.
        """
        try:
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )
            # log the response even if the request failed for troubleshooting purposes
            self._log_response(response)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                # Rate limiting and server-side errors are retryable: hand them to the backoff decorator
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(request=e.response.request, response=e.response)
            if self._wrap_refresh_token_exception(e):
                # The error matches the configured refresh token error: report it as a config error
                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                raise AirbyteTracedException(
                    internal_message=message, message=message, failure_type=FailureType.config_error
                )
            # Any other request error is propagated untouched
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e
236
237    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
238        """
239        Ensures that the access token is present in the response data.
240
241        This method attempts to extract the access token from the provided response data.
242        If the access token is not found, it raises an exception indicating that the token
243        refresh API response was missing the access token. If the access token is found,
244        it adds the token to the list of secrets to ensure it is replaced before logging
245        the response.
246
247        Args:
248            response_data (Mapping[str, Any]): The response data from which to extract the access token.
249
250        Raises:
251            Exception: If the access token is not found in the response data.
252            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
253        """
254        try:
255            access_key = self._extract_access_token(response_data)
256            if not access_key:
257                raise Exception(
258                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
259                )
260            # Add the access token to the list of secrets so it is replaced before logging the response
261            # An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
262            add_to_secrets(access_key)
263        except ResponseKeysMaxRecurtionReached as e:
264            raise e
265
266    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
267        """
268        Parse a string or integer token expiration date into a datetime object
269
270        :return: expiration datetime
271        """
272        if self.token_expiry_is_time_of_expiration:
273            if not self.token_expiry_date_format:
274                raise ValueError(
275                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
276                )
277            try:
278                return ab_datetime_parse(str(value))
279            except ValueError as e:
280                raise ValueError(f"Invalid token expiry date format: {e}")
281        else:
282            try:
283                # Only accept numeric values (as int/float/string) when no format specified
284                seconds = int(float(str(value)))
285                return ab_datetime_now() + timedelta(seconds=seconds)
286            except (ValueError, TypeError):
287                raise ValueError(
288                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
289                )
290
291    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
292        """
293        Extracts the access token from the given response data.
294
295        Args:
296            response_data (Mapping[str, Any]): The response data from which to extract the access token.
297
298        Returns:
299            str: The extracted access token.
300        """
301        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
302
303    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
304        """
305        Extracts the refresh token from the given response data.
306
307        Args:
308            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
309
310        Returns:
311            str: The extracted refresh token.
312        """
313        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
314
315    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
316        """
317        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
318
319        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
320
321        Args:
322            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
323
324        Returns:
325            The extracted token_expiry_date or None if not found.
326        """
327        expires_in = self._find_and_get_value_from_response(
328            response_data, self.get_expires_in_name()
329        )
330        if expires_in is not None:
331            return self._parse_token_expiration_date(expires_in)
332
333        # expires_in is None
334        existing_expiry_date = self.get_token_expiry_date()
335        if existing_expiry_date and not self.token_has_expired():
336            return existing_expiry_date
337
338        return self._default_token_expiry_date()
339
340    def _find_and_get_value_from_response(
341        self,
342        response_data: Mapping[str, Any],
343        key_name: str,
344        max_depth: int = 5,
345        current_depth: int = 0,
346    ) -> Any:
347        """
348        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
349
350        Args:
351            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
352            key_name (str): The key to search for in the response data.
353            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
354            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
355
356        Returns:
357            Any: The value associated with the specified key if found, otherwise None.
358
359        Raises:
360            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
361        """
362        if current_depth > max_depth:
363            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
364            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
365            raise ResponseKeysMaxRecurtionReached(
366                internal_message=message, message=message, failure_type=FailureType.config_error
367            )
368
369        if isinstance(response_data, dict):
370            # get from the root level
371            if key_name in response_data:
372                return response_data[key_name]
373
374            # get from the nested object
375            for _, value in response_data.items():
376                result = self._find_and_get_value_from_response(
377                    value, key_name, max_depth, current_depth + 1
378                )
379                if result is not None:
380                    return result
381
382        # get from the nested array object
383        elif isinstance(response_data, list):
384            for item in response_data:
385                result = self._find_and_get_value_from_response(
386                    item, key_name, max_depth, current_depth + 1
387                )
388                if result is not None:
389                    return result
390
391        return None
392
    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        Message repository used to emit debug logs for the token refresh HTTP request.

        The base implementation returns the shared module-level `_NOOP_MESSAGE_REPOSITORY`;
        the implementation can define its own message repository if it wants debugging logs
        for HTTP requests.
        """
        return _NOOP_MESSAGE_REPOSITORY
399
400    def _log_response(self, response: requests.Response) -> None:
401        """
402        Logs the HTTP response using the message repository if it is available.
403
404        Args:
405            response (requests.Response): The HTTP response to log.
406        """
407        if self._message_repository:
408            self._message_repository.log_message(
409                Level.DEBUG,
410                lambda: format_http_message(
411                    response,
412                    "Refresh token",
413                    "Obtains access token",
414                    self._NO_STREAM_NAME,
415                    is_auxiliary=True,
416                    type="AUTH",
417                ),
418            )
419
420    # ----------------
421    # ABSTR METHODS
422    # ----------------
423
    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token, or None when no refresh endpoint is configured"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """Returns the name of the request body field that carries the client id"""

    @abstractmethod
    def get_client_id(self) -> str:
        """Returns the client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """Returns the name of the request body field that carries the client secret"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """Returns the client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """Returns the name of the field that carries the refresh token (request body and response lookup)"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """Returns the token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """Returns the list of requested scopes"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Returns the expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Returns the name of the field to extract the access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the name of the response field that carries the token expiry value"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns the additional request body fields to set on the refresh request"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns the name of the request body field that carries the grant_type"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
logger = <Logger airbyte (INFO)>
class ResponseKeysMaxRecurtionReached(airbyte_cdk.utils.traced_exception.AirbyteTracedException):
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the recursive search for a target key in the token refresh
    response goes deeper than the allowed maximum depth
    (see `AbstractOauth2Authenticator._find_and_get_value_from_response`).
    """

Raised when the recursive search for a target key in the token refresh response goes deeper than the allowed maximum depth (see AbstractOauth2Authenticator._find_and_get_value_from_response).

class AbstractOauth2Authenticator(requests.auth.AuthBase):
 36class AbstractOauth2Authenticator(AuthBase):
 37    """
 38    Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
 39    is designed to generically perform the refresh flow without regard to how config fields are get/set by
 40    delegating that behavior to the classes implementing the interface.
 41    """
 42
 43    _NO_STREAM_NAME = None
 44
 45    def __init__(
 46        self,
 47        refresh_token_error_status_codes: Tuple[int, ...] = (),
 48        refresh_token_error_key: str = "",
 49        refresh_token_error_values: Tuple[str, ...] = (),
 50    ) -> None:
 51        """
 52        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 53        then http errors with such params will be wrapped in AirbyteTracedException.
 54        """
 55        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 56        self._refresh_token_error_key = refresh_token_error_key
 57        self._refresh_token_error_values = refresh_token_error_values
 58
 59    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 60        """Attach the HTTP headers required to authenticate on the HTTP request"""
 61        request.headers.update(self.get_auth_header())
 62        return request
 63
 64    @property
 65    def _is_access_token_flow(self) -> bool:
 66        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 67
 68    @property
 69    def token_expiry_is_time_of_expiration(self) -> bool:
 70        """
 71        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
 72        """
 73
 74        return False
 75
 76    @property
 77    def token_expiry_date_format(self) -> Optional[str]:
 78        """
 79        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
 80        """
 81
 82        return None
 83
 84    def get_auth_header(self) -> Mapping[str, Any]:
 85        """HTTP header to set on the requests"""
 86        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 87        return {"Authorization": f"Bearer {token}"}
 88
 89    def get_access_token(self) -> str:
 90        """Returns the access token"""
 91        if self.token_has_expired():
 92            token, expires_in = self.refresh_access_token()
 93            self.access_token = token
 94            self.set_token_expiry_date(expires_in)
 95
 96        return self.access_token
 97
 98    def token_has_expired(self) -> bool:
 99        """Returns True if the token is expired"""
100        return ab_datetime_now() > self.get_token_expiry_date()
101
102    def build_refresh_request_body(self) -> Mapping[str, Any]:
103        """
104        Returns the request body to set on the refresh request
105
106        Override to define additional parameters
107        """
108        payload: MutableMapping[str, Any] = {
109            self.get_grant_type_name(): self.get_grant_type(),
110            self.get_client_id_name(): self.get_client_id(),
111            self.get_client_secret_name(): self.get_client_secret(),
112            self.get_refresh_token_name(): self.get_refresh_token(),
113        }
114
115        if self.get_scopes():
116            payload["scopes"] = self.get_scopes()
117
118        if self.get_refresh_request_body():
119            for key, val in self.get_refresh_request_body().items():
120                # We defer to existing oauth constructs over custom configured fields
121                if key not in payload:
122                    payload[key] = val
123
124        return payload
125
126    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
127        """
128        Returns the request headers to set on the refresh request
129
130        """
131        headers = self.get_refresh_request_headers()
132        return headers if headers else None
133
134    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
135        """
136        Returns the refresh token and its expiration datetime
137
138        :return: a tuple of (access_token, token_lifespan)
139        """
140        response_json = self._make_handled_request()
141        self._ensure_access_token_in_response(response_json)
142
143        return (
144            self._extract_access_token(response_json),
145            self._extract_token_expiry_date(response_json),
146        )
147
148    # ----------------
149    # PRIVATE METHODS
150    # ----------------
151
152    def _default_token_expiry_date(self) -> AirbyteDateTime:
153        """
154        Returns the default token expiry date
155        """
156        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
157        default_token_expiry_duration_hours = 1  # 1 hour
158        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
159
160    def _wrap_refresh_token_exception(
161        self, exception: requests.exceptions.RequestException
162    ) -> bool:
163        """
164        Wraps and handles exceptions that occur during the refresh token process.
165
166        This method checks if the provided exception is related to a refresh token error
167        by examining the response status code and specific error content.
168
169        Args:
170            exception (requests.exceptions.RequestException): The exception raised during the request.
171
172        Returns:
173            bool: True if the exception is related to a refresh token error, False otherwise.
174        """
175        try:
176            if exception.response is not None:
177                exception_content = exception.response.json()
178            else:
179                return False
180        except JSONDecodeError:
181            return False
182        return (
183            exception.response.status_code in self._refresh_token_error_status_codes
184            and exception_content.get(self._refresh_token_error_key)
185            in self._refresh_token_error_values
186        )
187
188    @backoff.on_exception(
189        backoff.expo,
190        DefaultBackoffException,
191        on_backoff=lambda details: logger.info(
192            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
193        ),
194        max_time=300,
195    )
196    def _make_handled_request(self) -> Any:
197        """
198        Makes a handled HTTP request to refresh an OAuth token.
199
200        This method sends a POST request to the token refresh endpoint with the necessary
201        headers and body to obtain a new access token. It handles various exceptions that
202        may occur during the request and logs the response for troubleshooting purposes.
203
204        Returns:
205            Mapping[str, Any]: The JSON response from the token refresh endpoint.
206
207        Raises:
208            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
209                                     or any 5xx server error.
210            AirbyteTracedException: If the refresh token is invalid or expired, prompting
211                                    re-authentication.
212            Exception: For any other exceptions that occur during the request.
213        """
214        try:
215            response = requests.request(
216                method="POST",
217                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
218                data=self.build_refresh_request_body(),
219                headers=self.build_refresh_request_headers(),
220            )
221            # log the response even if the request failed for troubleshooting purposes
222            self._log_response(response)
223            response.raise_for_status()
224            return response.json()
225        except requests.exceptions.RequestException as e:
226            if e.response is not None:
227                if e.response.status_code == 429 or e.response.status_code >= 500:
228                    raise DefaultBackoffException(request=e.response.request, response=e.response)
229            if self._wrap_refresh_token_exception(e):
230                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
231                raise AirbyteTracedException(
232                    internal_message=message, message=message, failure_type=FailureType.config_error
233                )
234            raise
235        except Exception as e:
236            raise Exception(f"Error while refreshing access token: {e}") from e
237
238    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
239        """
240        Ensures that the access token is present in the response data.
241
242        This method attempts to extract the access token from the provided response data.
243        If the access token is not found, it raises an exception indicating that the token
244        refresh API response was missing the access token. If the access token is found,
245        it adds the token to the list of secrets to ensure it is replaced before logging
246        the response.
247
248        Args:
249            response_data (Mapping[str, Any]): The response data from which to extract the access token.
250
251        Raises:
252            Exception: If the access token is not found in the response data.
253            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
254        """
255        try:
256            access_key = self._extract_access_token(response_data)
257            if not access_key:
258                raise Exception(
259                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
260                )
261            # Add the access token to the list of secrets so it is replaced before logging the response
262            # An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
263            add_to_secrets(access_key)
264        except ResponseKeysMaxRecurtionReached as e:
265            raise e
266
267    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
268        """
269        Parse a string or integer token expiration date into a datetime object
270
271        :return: expiration datetime
272        """
273        if self.token_expiry_is_time_of_expiration:
274            if not self.token_expiry_date_format:
275                raise ValueError(
276                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
277                )
278            try:
279                return ab_datetime_parse(str(value))
280            except ValueError as e:
281                raise ValueError(f"Invalid token expiry date format: {e}")
282        else:
283            try:
284                # Only accept numeric values (as int/float/string) when no format specified
285                seconds = int(float(str(value)))
286                return ab_datetime_now() + timedelta(seconds=seconds)
287            except (ValueError, TypeError):
288                raise ValueError(
289                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
290                )
291
292    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
293        """
294        Extracts the access token from the given response data.
295
296        Args:
297            response_data (Mapping[str, Any]): The response data from which to extract the access token.
298
299        Returns:
300            str: The extracted access token.
301        """
302        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
303
304    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
305        """
306        Extracts the refresh token from the given response data.
307
308        Args:
309            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
310
311        Returns:
312            str: The extracted refresh token.
313        """
314        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
315
316    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
317        """
318        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
319
320        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
321
322        Args:
323            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
324
325        Returns:
326            The extracted token_expiry_date or None if not found.
327        """
328        expires_in = self._find_and_get_value_from_response(
329            response_data, self.get_expires_in_name()
330        )
331        if expires_in is not None:
332            return self._parse_token_expiration_date(expires_in)
333
334        # expires_in is None
335        existing_expiry_date = self.get_token_expiry_date()
336        if existing_expiry_date and not self.token_has_expired():
337            return existing_expiry_date
338
339        return self._default_token_expiry_date()
340
341    def _find_and_get_value_from_response(
342        self,
343        response_data: Mapping[str, Any],
344        key_name: str,
345        max_depth: int = 5,
346        current_depth: int = 0,
347    ) -> Any:
348        """
349        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
350
351        Args:
352            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
353            key_name (str): The key to search for in the response data.
354            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
355            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
356
357        Returns:
358            Any: The value associated with the specified key if found, otherwise None.
359
360        Raises:
361            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
362        """
363        if current_depth > max_depth:
364            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
365            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
366            raise ResponseKeysMaxRecurtionReached(
367                internal_message=message, message=message, failure_type=FailureType.config_error
368            )
369
370        if isinstance(response_data, dict):
371            # get from the root level
372            if key_name in response_data:
373                return response_data[key_name]
374
375            # get from the nested object
376            for _, value in response_data.items():
377                result = self._find_and_get_value_from_response(
378                    value, key_name, max_depth, current_depth + 1
379                )
380                if result is not None:
381                    return result
382
383        # get from the nested array object
384        elif isinstance(response_data, list):
385            for item in response_data:
386                result = self._find_and_get_value_from_response(
387                    item, key_name, max_depth, current_depth + 1
388                )
389                if result is not None:
390                    return result
391
392        return None
393
    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        Message repository that `_log_response` uses to emit debug logs for the token refresh response.

        This base implementation returns the module-level `_NOOP_MESSAGE_REPOSITORY`. An
        implementation can define its own message_repository if it wants debugging logs for HTTP requests.
        """
        return _NOOP_MESSAGE_REPOSITORY
400
401    def _log_response(self, response: requests.Response) -> None:
402        """
403        Logs the HTTP response using the message repository if it is available.
404
405        Args:
406            response (requests.Response): The HTTP response to log.
407        """
408        if self._message_repository:
409            self._message_repository.log_message(
410                Level.DEBUG,
411                lambda: format_http_message(
412                    response,
413                    "Refresh token",
414                    "Obtains access token",
415                    self._NO_STREAM_NAME,
416                    is_auxiliary=True,
417                    type="AUTH",
418                ),
419            )
420
421    # ----------------
422    # ABSTR METHODS
423    # ----------------
424
    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token (may be None, per the return type); provided by the implementing class"""
428
    @abstractmethod
    def get_client_id_name(self) -> str:
        """The client id name to authenticate; `build_refresh_request_body` uses it as the payload key for `get_client_id()`"""
432
    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate; `build_refresh_request_body` sends it under the key from `get_client_id_name()`"""
436
    @abstractmethod
    def get_client_secret_name(self) -> str:
        """The client secret name to authenticate; `build_refresh_request_body` uses it as the payload key for `get_client_secret()`"""
440
    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate; `build_refresh_request_body` sends it under the key from `get_client_secret_name()`"""
444
    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """
        The refresh token name to authenticate.

        Used as the payload key for `get_refresh_token()` in `build_refresh_request_body`, and as
        the key that `_extract_refresh_token` searches for in the refresh response.
        """
448
    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires; `build_refresh_request_body` sends it under the key from `get_refresh_token_name()`"""
452
    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes; when non-empty, `build_refresh_request_body` adds it to the payload under the "scopes" key"""
456
    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """
        Expiration date of the access token.

        `token_has_expired` compares it against the current time, and `_extract_token_expiry_date`
        falls back to it when the refresh response carries no expiry value.
        """
460
    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date; `get_access_token` calls it with the expiry date returned by `refresh_access_token`"""
464
    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response; this is the key `_extract_access_token` searches for in the refresh response"""
468
    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the expires_in field name; this is the key `_extract_token_expiry_date` searches for in the refresh response"""
472
    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request.

        `build_refresh_request_body` merges these entries into the payload, skipping any key the
        payload already contains.
        """
476
    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request; `build_refresh_request_headers` passes them through, or returns None when they are empty"""
480
    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token; `build_refresh_request_body` sends it under the key from `get_grant_type_name()`"""
484
    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token; `build_refresh_request_body` uses it as the payload key for `get_grant_type()`"""
488
    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token currently held by the authenticator; `get_access_token` returns this value"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token; `get_access_token` assigns the refreshed token through it"""

Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator is designed to generically perform the refresh flow without regard to how config fields are read or written, delegating that behavior to the classes implementing the interface.

AbstractOauth2Authenticator( refresh_token_error_status_codes: Tuple[int, ...] = (), refresh_token_error_key: str = '', refresh_token_error_values: Tuple[str, ...] = ())
    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        Store the criteria that identify an invalid or expired refresh token.

        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then http errors with such params will be wrapped in AirbyteTracedException.

        Args:
            refresh_token_error_status_codes (Tuple[int, ...]): Status codes to match. Defaults to ().
            refresh_token_error_key (str): Key to match. Defaults to "".
            refresh_token_error_values (Tuple[str, ...]): Values to match. Defaults to ().
        """
        # Kept on private attributes; the matching logic that reads them is outside this method.
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, then http errors with such params will be wrapped in AirbyteTracedException.

token_expiry_is_time_of_expiration: bool
    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.

        `_parse_token_expiration_date` reads this flag: when true the expiry value is parsed as a
        datetime, otherwise it is treated as a number of seconds from now. This base
        implementation returns False.
        """

        return False

Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.

token_expiry_date_format: Optional[str]
    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the datetime; required if expires_in is returned as the expiration datetime instead of seconds until it expires.

        `_parse_token_expiration_date` raises ValueError when `token_expiry_is_time_of_expiration`
        is true and this value is not set. This base implementation returns None.
        """

        return None

Format of the datetime; required if expires_in is returned as the expiration datetime instead of the number of seconds until it expires

def get_auth_header(self) -> Mapping[str, Any]:
84    def get_auth_header(self) -> Mapping[str, Any]:
85        """HTTP header to set on the requests"""
86        token = self.access_token if self._is_access_token_flow else self.get_access_token()
87        return {"Authorization": f"Bearer {token}"}

HTTP header to set on the requests

def get_access_token(self) -> str:
89    def get_access_token(self) -> str:
90        """Returns the access token"""
91        if self.token_has_expired():
92            token, expires_in = self.refresh_access_token()
93            self.access_token = token
94            self.set_token_expiry_date(expires_in)
95
96        return self.access_token

Returns the access token

def token_has_expired(self) -> bool:
    def token_has_expired(self) -> bool:
        """Returns True if the token is expired, i.e. the current time is strictly later than `get_token_expiry_date()`"""
        return ab_datetime_now() > self.get_token_expiry_date()

Returns True if the token is expired

def build_refresh_request_body(self) -> Mapping[str, Any]:
102    def build_refresh_request_body(self) -> Mapping[str, Any]:
103        """
104        Returns the request body to set on the refresh request
105
106        Override to define additional parameters
107        """
108        payload: MutableMapping[str, Any] = {
109            self.get_grant_type_name(): self.get_grant_type(),
110            self.get_client_id_name(): self.get_client_id(),
111            self.get_client_secret_name(): self.get_client_secret(),
112            self.get_refresh_token_name(): self.get_refresh_token(),
113        }
114
115        if self.get_scopes():
116            payload["scopes"] = self.get_scopes()
117
118        if self.get_refresh_request_body():
119            for key, val in self.get_refresh_request_body().items():
120                # We defer to existing oauth constructs over custom configured fields
121                if key not in payload:
122                    payload[key] = val
123
124        return payload

Returns the request body to set on the refresh request

Override to define additional parameters

def build_refresh_request_headers(self) -> Optional[Mapping[str, Any]]:
126    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
127        """
128        Returns the request headers to set on the refresh request
129
130        """
131        headers = self.get_refresh_request_headers()
132        return headers if headers else None

Returns the request headers to set on the refresh request

def refresh_access_token(self) -> Tuple[str, airbyte_cdk.utils.datetime_helpers.AirbyteDateTime]:
134    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
135        """
136        Returns the refresh token and its expiration datetime
137
138        :return: a tuple of (access_token, token_lifespan)
139        """
140        response_json = self._make_handled_request()
141        self._ensure_access_token_in_response(response_json)
142
143        return (
144            self._extract_access_token(response_json),
145            self._extract_token_expiry_date(response_json),
146        )

Returns the access token and its expiration datetime

Returns

a tuple of (access_token, token_expiry_date)

@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
425    @abstractmethod
426    def get_token_refresh_endpoint(self) -> Optional[str]:
427        """Returns the endpoint to refresh the access token"""

Returns the endpoint to refresh the access token

@abstractmethod
def get_client_id_name(self) -> str:
429    @abstractmethod
430    def get_client_id_name(self) -> str:
431        """The client id name to authenticate"""

The client id name to authenticate

@abstractmethod
def get_client_id(self) -> str:
433    @abstractmethod
434    def get_client_id(self) -> str:
435        """The client id to authenticate"""

The client id to authenticate

@abstractmethod
def get_client_secret_name(self) -> str:
437    @abstractmethod
438    def get_client_secret_name(self) -> str:
439        """The client secret name to authenticate"""

The client secret name to authenticate

@abstractmethod
def get_client_secret(self) -> str:
441    @abstractmethod
442    def get_client_secret(self) -> str:
443        """The client secret to authenticate"""

The client secret to authenticate

@abstractmethod
def get_refresh_token_name(self) -> str:
445    @abstractmethod
446    def get_refresh_token_name(self) -> str:
447        """The refresh token name to authenticate"""

The refresh token name to authenticate

@abstractmethod
def get_refresh_token(self) -> Optional[str]:
449    @abstractmethod
450    def get_refresh_token(self) -> Optional[str]:
451        """The token used to refresh the access token when it expires"""

The token used to refresh the access token when it expires

@abstractmethod
def get_scopes(self) -> List[str]:
453    @abstractmethod
454    def get_scopes(self) -> List[str]:
455        """List of requested scopes"""

List of requested scopes

@abstractmethod
def get_token_expiry_date(self) -> airbyte_cdk.utils.datetime_helpers.AirbyteDateTime:
457    @abstractmethod
458    def get_token_expiry_date(self) -> AirbyteDateTime:
459        """Expiration date of the access token"""

Expiration date of the access token

@abstractmethod
def set_token_expiry_date(self, value: airbyte_cdk.utils.datetime_helpers.AirbyteDateTime) -> None:
461    @abstractmethod
462    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
463        """Setter for access token expiration date"""

Setter for access token expiration date

@abstractmethod
def get_access_token_name(self) -> str:
465    @abstractmethod
466    def get_access_token_name(self) -> str:
467        """Field to extract access token from in the response"""

Field to extract access token from in the response

@abstractmethod
def get_expires_in_name(self) -> str:
469    @abstractmethod
470    def get_expires_in_name(self) -> str:
471        """Returns the expires_in field name"""

Returns the expires_in field name

@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
473    @abstractmethod
474    def get_refresh_request_body(self) -> Mapping[str, Any]:
475        """Returns the request body to set on the refresh request"""

Returns the request body to set on the refresh request

@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
477    @abstractmethod
478    def get_refresh_request_headers(self) -> Mapping[str, Any]:
479        """Returns the request headers to set on the refresh request"""

Returns the request headers to set on the refresh request

@abstractmethod
def get_grant_type(self) -> str:
481    @abstractmethod
482    def get_grant_type(self) -> str:
483        """Returns grant_type specified for requesting access_token"""

Returns grant_type specified for requesting access_token

@abstractmethod
def get_grant_type_name(self) -> str:
485    @abstractmethod
486    def get_grant_type_name(self) -> str:
487        """Returns grant_type specified name for requesting access_token"""

Returns grant_type specified name for requesting access_token

access_token: str
489    @property
490    @abstractmethod
491    def access_token(self) -> str:
492        """Returns the access token"""

Returns the access token