airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth

  1#
  2# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
  3#
  4
  5import logging
  6from abc import abstractmethod
  7from datetime import timedelta
  8from json import JSONDecodeError
  9from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union
 10
 11import backoff
 12import requests
 13from requests.auth import AuthBase
 14
 15from airbyte_cdk.models import FailureType, Level
 16from airbyte_cdk.sources.http_logger import format_http_message
 17from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
 18from airbyte_cdk.utils import AirbyteTracedException
 19from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
 20from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse
 21
 22from ..exceptions import DefaultBackoffException
 23
 24logger = logging.getLogger("airbyte")
 25_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()
 26
 27
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is reached while searching a token
    refresh response for a target key (see `_find_and_get_value_from_response`),
    for example during `_make_handled_request`.
    """
 33
 34
 35class AbstractOauth2Authenticator(AuthBase):
 36    """
 37    Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
 38    is designed to generically perform the refresh flow without regard to how config fields are get/set by
 39    delegating that behavior to the classes implementing the interface.
 40    """
 41
 42    _NO_STREAM_NAME = None
 43
 44    def __init__(
 45        self,
 46        refresh_token_error_status_codes: Tuple[int, ...] = (),
 47        refresh_token_error_key: str = "",
 48        refresh_token_error_values: Tuple[str, ...] = (),
 49    ) -> None:
 50        """
 51        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 52        then http errors with such params will be wrapped in AirbyteTracedException.
 53        """
 54        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 55        self._refresh_token_error_key = refresh_token_error_key
 56        self._refresh_token_error_values = refresh_token_error_values
 57
 58    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 59        """Attach the HTTP headers required to authenticate on the HTTP request"""
 60        request.headers.update(self.get_auth_header())
 61        return request
 62
 63    @property
 64    def _is_access_token_flow(self) -> bool:
 65        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 66
 67    @property
 68    def token_expiry_is_time_of_expiration(self) -> bool:
 69        """
 70        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
 71        """
 72
 73        return False
 74
 75    @property
 76    def token_expiry_date_format(self) -> Optional[str]:
 77        """
 78        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
 79        """
 80
 81        return None
 82
 83    def get_auth_header(self) -> Mapping[str, Any]:
 84        """HTTP header to set on the requests"""
 85        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 86        return {"Authorization": f"Bearer {token}"}
 87
 88    def get_access_token(self) -> str:
 89        """Returns the access token"""
 90        if self.token_has_expired():
 91            token, expires_in = self.refresh_access_token()
 92            self.access_token = token
 93            self.set_token_expiry_date(expires_in)
 94
 95        return self.access_token
 96
 97    def token_has_expired(self) -> bool:
 98        """Returns True if the token is expired"""
 99        return ab_datetime_now() > self.get_token_expiry_date()
100
101    def build_refresh_request_body(self) -> Mapping[str, Any]:
102        """
103        Returns the request body to set on the refresh request
104
105        Override to define additional parameters
106        """
107        payload: MutableMapping[str, Any] = {
108            self.get_grant_type_name(): self.get_grant_type(),
109            self.get_client_id_name(): self.get_client_id(),
110            self.get_client_secret_name(): self.get_client_secret(),
111            self.get_refresh_token_name(): self.get_refresh_token(),
112        }
113
114        if self.get_scopes():
115            payload["scopes"] = self.get_scopes()
116
117        if self.get_refresh_request_body():
118            for key, val in self.get_refresh_request_body().items():
119                # We defer to existing oauth constructs over custom configured fields
120                if key not in payload:
121                    payload[key] = val
122
123        return payload
124
125    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
126        """
127        Returns the request headers to set on the refresh request
128
129        """
130        headers = self.get_refresh_request_headers()
131        return headers if headers else None
132
133    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
134        """
135        Returns the refresh token and its expiration datetime
136
137        :return: a tuple of (access_token, token_lifespan)
138        """
139        response_json = self._make_handled_request()
140        self._ensure_access_token_in_response(response_json)
141
142        return (
143            self._extract_access_token(response_json),
144            self._extract_token_expiry_date(response_json),
145        )
146
147    # ----------------
148    # PRIVATE METHODS
149    # ----------------
150
151    def _default_token_expiry_date(self) -> AirbyteDateTime:
152        """
153        Returns the default token expiry date
154        """
155        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
156        default_token_expiry_duration_hours = 1  # 1 hour
157        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
158
159    def _wrap_refresh_token_exception(
160        self, exception: requests.exceptions.RequestException
161    ) -> bool:
162        """
163        Wraps and handles exceptions that occur during the refresh token process.
164
165        This method checks if the provided exception is related to a refresh token error
166        by examining the response status code and specific error content.
167
168        Args:
169            exception (requests.exceptions.RequestException): The exception raised during the request.
170
171        Returns:
172            bool: True if the exception is related to a refresh token error, False otherwise.
173        """
174        try:
175            if exception.response is not None:
176                exception_content = exception.response.json()
177            else:
178                return False
179        except JSONDecodeError:
180            return False
181        return (
182            exception.response.status_code in self._refresh_token_error_status_codes
183            and exception_content.get(self._refresh_token_error_key)
184            in self._refresh_token_error_values
185        )
186
187    @backoff.on_exception(
188        backoff.expo,
189        DefaultBackoffException,
190        on_backoff=lambda details: logger.info(
191            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
192        ),
193        max_time=300,
194    )
195    def _make_handled_request(self) -> Any:
196        """
197        Makes a handled HTTP request to refresh an OAuth token.
198
199        This method sends a POST request to the token refresh endpoint with the necessary
200        headers and body to obtain a new access token. It handles various exceptions that
201        may occur during the request and logs the response for troubleshooting purposes.
202
203        Returns:
204            Mapping[str, Any]: The JSON response from the token refresh endpoint.
205
206        Raises:
207            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
208                                     or any 5xx server error.
209            AirbyteTracedException: If the refresh token is invalid or expired, prompting
210                                    re-authentication.
211            Exception: For any other exceptions that occur during the request.
212        """
213        try:
214            response = requests.request(
215                method="POST",
216                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
217                data=self.build_refresh_request_body(),
218                headers=self.build_refresh_request_headers(),
219            )
220
221            if not response.ok:
222                # log the response even if the request failed for troubleshooting purposes
223                self._log_response(response)
224                response.raise_for_status()
225
226            response_json = response.json()
227
228            try:
229                # extract the access token and add to secrets to avoid logging the raw value
230                access_key = self._extract_access_token(response_json)
231                if access_key:
232                    add_to_secrets(access_key)
233            except ResponseKeysMaxRecurtionReached as e:
234                # could not find the access token in the response, so do nothing
235                pass
236
237            self._log_response(response)
238
239            return response_json
240        except requests.exceptions.RequestException as e:
241            if e.response is not None:
242                if e.response.status_code == 429 or e.response.status_code >= 500:
243                    raise DefaultBackoffException(
244                        request=e.response.request,
245                        response=e.response,
246                        failure_type=FailureType.transient_error,
247                    )
248            if self._wrap_refresh_token_exception(e):
249                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
250                raise AirbyteTracedException(
251                    internal_message=message, message=message, failure_type=FailureType.config_error
252                )
253            raise
254        except Exception as e:
255            raise Exception(f"Error while refreshing access token: {e}") from e
256
257    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
258        """
259        Ensures that the access token is present in the response data.
260
261        This method attempts to extract the access token from the provided response data.
262        If the access token is not found, it raises an exception indicating that the token
263        refresh API response was missing the access token.
264
265        Args:
266            response_data (Mapping[str, Any]): The response data from which to extract the access token.
267
268        Raises:
269            Exception: If the access token is not found in the response data.
270            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
271        """
272        try:
273            access_key = self._extract_access_token(response_data)
274            if not access_key:
275                raise Exception(
276                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
277                )
278        except ResponseKeysMaxRecurtionReached as e:
279            raise e
280
281    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
282        """
283        Parse a string or integer token expiration date into a datetime object
284
285        :return: expiration datetime
286        """
287        if self.token_expiry_is_time_of_expiration:
288            if not self.token_expiry_date_format:
289                raise ValueError(
290                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
291                )
292            try:
293                return ab_datetime_parse(str(value))
294            except ValueError as e:
295                raise ValueError(f"Invalid token expiry date format: {e}")
296        else:
297            try:
298                # Only accept numeric values (as int/float/string) when no format specified
299                seconds = int(float(str(value)))
300                return ab_datetime_now() + timedelta(seconds=seconds)
301            except (ValueError, TypeError):
302                raise ValueError(
303                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
304                )
305
306    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
307        """
308        Extracts the access token from the given response data.
309
310        Args:
311            response_data (Mapping[str, Any]): The response data from which to extract the access token.
312
313        Returns:
314            str: The extracted access token.
315        """
316        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
317
318    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
319        """
320        Extracts the refresh token from the given response data.
321
322        Args:
323            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
324
325        Returns:
326            str: The extracted refresh token.
327        """
328        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
329
330    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
331        """
332        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
333
334        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
335
336        Args:
337            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
338
339        Returns:
340            The extracted token_expiry_date or None if not found.
341        """
342        expires_in = self._find_and_get_value_from_response(
343            response_data, self.get_expires_in_name()
344        )
345        if expires_in is not None:
346            return self._parse_token_expiration_date(expires_in)
347
348        # expires_in is None
349        existing_expiry_date = self.get_token_expiry_date()
350        if existing_expiry_date and not self.token_has_expired():
351            return existing_expiry_date
352
353        return self._default_token_expiry_date()
354
355    def _find_and_get_value_from_response(
356        self,
357        response_data: Mapping[str, Any],
358        key_name: str,
359        max_depth: int = 5,
360        current_depth: int = 0,
361    ) -> Any:
362        """
363        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
364
365        Args:
366            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
367            key_name (str): The key to search for in the response data.
368            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
369            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
370
371        Returns:
372            Any: The value associated with the specified key if found, otherwise None.
373
374        Raises:
375            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
376        """
377        if current_depth > max_depth:
378            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
379            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
380            raise ResponseKeysMaxRecurtionReached(
381                internal_message=message, message=message, failure_type=FailureType.config_error
382            )
383
384        if isinstance(response_data, dict):
385            # get from the root level
386            if key_name in response_data:
387                return response_data[key_name]
388
389            # get from the nested object
390            for _, value in response_data.items():
391                result = self._find_and_get_value_from_response(
392                    value, key_name, max_depth, current_depth + 1
393                )
394                if result is not None:
395                    return result
396
397        # get from the nested array object
398        elif isinstance(response_data, list):
399            for item in response_data:
400                result = self._find_and_get_value_from_response(
401                    item, key_name, max_depth, current_depth + 1
402                )
403                if result is not None:
404                    return result
405
406        return None
407
408    @property
409    def _message_repository(self) -> Optional[MessageRepository]:
410        """
411        The implementation can define a message_repository if it wants debugging logs for HTTP requests
412        """
413        return _NOOP_MESSAGE_REPOSITORY
414
415    def _log_response(self, response: requests.Response) -> None:
416        """
417        Logs the HTTP response using the message repository if it is available.
418
419        Args:
420            response (requests.Response): The HTTP response to log.
421        """
422        if self._message_repository:
423            self._message_repository.log_message(
424                Level.DEBUG,
425                lambda: format_http_message(
426                    response,
427                    "Refresh token",
428                    "Obtains access token",
429                    self._NO_STREAM_NAME,
430                    is_auxiliary=True,
431                    type="AUTH",
432                ),
433            )
434
435    # ----------------
436    # ABSTR METHODS
437    # ----------------
438
439    @abstractmethod
440    def get_token_refresh_endpoint(self) -> Optional[str]:
441        """Returns the endpoint to refresh the access token"""
442
443    @abstractmethod
444    def get_client_id_name(self) -> str:
445        """The client id name to authenticate"""
446
447    @abstractmethod
448    def get_client_id(self) -> str:
449        """The client id to authenticate"""
450
451    @abstractmethod
452    def get_client_secret_name(self) -> str:
453        """The client secret name to authenticate"""
454
455    @abstractmethod
456    def get_client_secret(self) -> str:
457        """The client secret to authenticate"""
458
459    @abstractmethod
460    def get_refresh_token_name(self) -> str:
461        """The refresh token name to authenticate"""
462
463    @abstractmethod
464    def get_refresh_token(self) -> Optional[str]:
465        """The token used to refresh the access token when it expires"""
466
467    @abstractmethod
468    def get_scopes(self) -> List[str]:
469        """List of requested scopes"""
470
471    @abstractmethod
472    def get_token_expiry_date(self) -> AirbyteDateTime:
473        """Expiration date of the access token"""
474
475    @abstractmethod
476    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
477        """Setter for access token expiration date"""
478
479    @abstractmethod
480    def get_access_token_name(self) -> str:
481        """Field to extract access token from in the response"""
482
483    @abstractmethod
484    def get_expires_in_name(self) -> str:
485        """Returns the expires_in field name"""
486
487    @abstractmethod
488    def get_refresh_request_body(self) -> Mapping[str, Any]:
489        """Returns the request body to set on the refresh request"""
490
491    @abstractmethod
492    def get_refresh_request_headers(self) -> Mapping[str, Any]:
493        """Returns the request headers to set on the refresh request"""
494
495    @abstractmethod
496    def get_grant_type(self) -> str:
497        """Returns grant_type specified for requesting access_token"""
498
499    @abstractmethod
500    def get_grant_type_name(self) -> str:
501        """Returns grant_type specified name for requesting access_token"""
502
503    @property
504    @abstractmethod
505    def access_token(self) -> str:
506        """Returns the access token"""
507
508    @access_token.setter
509    @abstractmethod
510    def access_token(self, value: str) -> str:
511        """Setter for the access token"""
logger = <Logger airbyte (INFO)>
class ResponseKeysMaxRecurtionReached(airbyte_cdk.utils.traced_exception.AirbyteTracedException):
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is reached while searching a token
    refresh response for a target key (see `_find_and_get_value_from_response`),
    for example during `_make_handled_request`.
    """

Raised when the maximum recursion depth is reached while searching the token refresh response for the target key, for example during `_make_handled_request`.

class AbstractOauth2Authenticator(requests.auth.AuthBase):
 36class AbstractOauth2Authenticator(AuthBase):
 37    """
 38    Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
 39    is designed to generically perform the refresh flow without regard to how config fields are get/set by
 40    delegating that behavior to the classes implementing the interface.
 41    """
 42
 43    _NO_STREAM_NAME = None
 44
 45    def __init__(
 46        self,
 47        refresh_token_error_status_codes: Tuple[int, ...] = (),
 48        refresh_token_error_key: str = "",
 49        refresh_token_error_values: Tuple[str, ...] = (),
 50    ) -> None:
 51        """
 52        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 53        then http errors with such params will be wrapped in AirbyteTracedException.
 54        """
 55        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 56        self._refresh_token_error_key = refresh_token_error_key
 57        self._refresh_token_error_values = refresh_token_error_values
 58
 59    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 60        """Attach the HTTP headers required to authenticate on the HTTP request"""
 61        request.headers.update(self.get_auth_header())
 62        return request
 63
 64    @property
 65    def _is_access_token_flow(self) -> bool:
 66        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 67
 68    @property
 69    def token_expiry_is_time_of_expiration(self) -> bool:
 70        """
 71        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
 72        """
 73
 74        return False
 75
 76    @property
 77    def token_expiry_date_format(self) -> Optional[str]:
 78        """
 79        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
 80        """
 81
 82        return None
 83
 84    def get_auth_header(self) -> Mapping[str, Any]:
 85        """HTTP header to set on the requests"""
 86        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 87        return {"Authorization": f"Bearer {token}"}
 88
 89    def get_access_token(self) -> str:
 90        """Returns the access token"""
 91        if self.token_has_expired():
 92            token, expires_in = self.refresh_access_token()
 93            self.access_token = token
 94            self.set_token_expiry_date(expires_in)
 95
 96        return self.access_token
 97
 98    def token_has_expired(self) -> bool:
 99        """Returns True if the token is expired"""
100        return ab_datetime_now() > self.get_token_expiry_date()
101
102    def build_refresh_request_body(self) -> Mapping[str, Any]:
103        """
104        Returns the request body to set on the refresh request
105
106        Override to define additional parameters
107        """
108        payload: MutableMapping[str, Any] = {
109            self.get_grant_type_name(): self.get_grant_type(),
110            self.get_client_id_name(): self.get_client_id(),
111            self.get_client_secret_name(): self.get_client_secret(),
112            self.get_refresh_token_name(): self.get_refresh_token(),
113        }
114
115        if self.get_scopes():
116            payload["scopes"] = self.get_scopes()
117
118        if self.get_refresh_request_body():
119            for key, val in self.get_refresh_request_body().items():
120                # We defer to existing oauth constructs over custom configured fields
121                if key not in payload:
122                    payload[key] = val
123
124        return payload
125
126    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
127        """
128        Returns the request headers to set on the refresh request
129
130        """
131        headers = self.get_refresh_request_headers()
132        return headers if headers else None
133
134    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
135        """
136        Returns the refresh token and its expiration datetime
137
138        :return: a tuple of (access_token, token_lifespan)
139        """
140        response_json = self._make_handled_request()
141        self._ensure_access_token_in_response(response_json)
142
143        return (
144            self._extract_access_token(response_json),
145            self._extract_token_expiry_date(response_json),
146        )
147
148    # ----------------
149    # PRIVATE METHODS
150    # ----------------
151
152    def _default_token_expiry_date(self) -> AirbyteDateTime:
153        """
154        Returns the default token expiry date
155        """
156        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
157        default_token_expiry_duration_hours = 1  # 1 hour
158        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
159
160    def _wrap_refresh_token_exception(
161        self, exception: requests.exceptions.RequestException
162    ) -> bool:
163        """
164        Wraps and handles exceptions that occur during the refresh token process.
165
166        This method checks if the provided exception is related to a refresh token error
167        by examining the response status code and specific error content.
168
169        Args:
170            exception (requests.exceptions.RequestException): The exception raised during the request.
171
172        Returns:
173            bool: True if the exception is related to a refresh token error, False otherwise.
174        """
175        try:
176            if exception.response is not None:
177                exception_content = exception.response.json()
178            else:
179                return False
180        except JSONDecodeError:
181            return False
182        return (
183            exception.response.status_code in self._refresh_token_error_status_codes
184            and exception_content.get(self._refresh_token_error_key)
185            in self._refresh_token_error_values
186        )
187
188    @backoff.on_exception(
189        backoff.expo,
190        DefaultBackoffException,
191        on_backoff=lambda details: logger.info(
192            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
193        ),
194        max_time=300,
195    )
196    def _make_handled_request(self) -> Any:
197        """
198        Makes a handled HTTP request to refresh an OAuth token.
199
200        This method sends a POST request to the token refresh endpoint with the necessary
201        headers and body to obtain a new access token. It handles various exceptions that
202        may occur during the request and logs the response for troubleshooting purposes.
203
204        Returns:
205            Mapping[str, Any]: The JSON response from the token refresh endpoint.
206
207        Raises:
208            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
209                                     or any 5xx server error.
210            AirbyteTracedException: If the refresh token is invalid or expired, prompting
211                                    re-authentication.
212            Exception: For any other exceptions that occur during the request.
213        """
214        try:
215            response = requests.request(
216                method="POST",
217                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
218                data=self.build_refresh_request_body(),
219                headers=self.build_refresh_request_headers(),
220            )
221
222            if not response.ok:
223                # log the response even if the request failed for troubleshooting purposes
224                self._log_response(response)
225                response.raise_for_status()
226
227            response_json = response.json()
228
229            try:
230                # extract the access token and add to secrets to avoid logging the raw value
231                access_key = self._extract_access_token(response_json)
232                if access_key:
233                    add_to_secrets(access_key)
234            except ResponseKeysMaxRecurtionReached as e:
235                # could not find the access token in the response, so do nothing
236                pass
237
238            self._log_response(response)
239
240            return response_json
241        except requests.exceptions.RequestException as e:
242            if e.response is not None:
243                if e.response.status_code == 429 or e.response.status_code >= 500:
244                    raise DefaultBackoffException(
245                        request=e.response.request,
246                        response=e.response,
247                        failure_type=FailureType.transient_error,
248                    )
249            if self._wrap_refresh_token_exception(e):
250                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
251                raise AirbyteTracedException(
252                    internal_message=message, message=message, failure_type=FailureType.config_error
253                )
254            raise
255        except Exception as e:
256            raise Exception(f"Error while refreshing access token: {e}") from e
257
258    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
259        """
260        Ensures that the access token is present in the response data.
261
262        This method attempts to extract the access token from the provided response data.
263        If the access token is not found, it raises an exception indicating that the token
264        refresh API response was missing the access token.
265
266        Args:
267            response_data (Mapping[str, Any]): The response data from which to extract the access token.
268
269        Raises:
270            Exception: If the access token is not found in the response data.
271            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
272        """
273        try:
274            access_key = self._extract_access_token(response_data)
275            if not access_key:
276                raise Exception(
277                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
278                )
279        except ResponseKeysMaxRecurtionReached as e:
280            raise e
281
282    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
283        """
284        Parse a string or integer token expiration date into a datetime object
285
286        :return: expiration datetime
287        """
288        if self.token_expiry_is_time_of_expiration:
289            if not self.token_expiry_date_format:
290                raise ValueError(
291                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
292                )
293            try:
294                return ab_datetime_parse(str(value))
295            except ValueError as e:
296                raise ValueError(f"Invalid token expiry date format: {e}")
297        else:
298            try:
299                # Only accept numeric values (as int/float/string) when no format specified
300                seconds = int(float(str(value)))
301                return ab_datetime_now() + timedelta(seconds=seconds)
302            except (ValueError, TypeError):
303                raise ValueError(
304                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
305                )
306
307    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
308        """
309        Extracts the access token from the given response data.
310
311        Args:
312            response_data (Mapping[str, Any]): The response data from which to extract the access token.
313
314        Returns:
315            str: The extracted access token.
316        """
317        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
318
319    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
320        """
321        Extracts the refresh token from the given response data.
322
323        Args:
324            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
325
326        Returns:
327            str: The extracted refresh token.
328        """
329        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
330
331    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
332        """
333        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
334
335        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
336
337        Args:
338            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
339
340        Returns:
341            The extracted token_expiry_date or None if not found.
342        """
343        expires_in = self._find_and_get_value_from_response(
344            response_data, self.get_expires_in_name()
345        )
346        if expires_in is not None:
347            return self._parse_token_expiration_date(expires_in)
348
349        # expires_in is None
350        existing_expiry_date = self.get_token_expiry_date()
351        if existing_expiry_date and not self.token_has_expired():
352            return existing_expiry_date
353
354        return self._default_token_expiry_date()
355
356    def _find_and_get_value_from_response(
357        self,
358        response_data: Mapping[str, Any],
359        key_name: str,
360        max_depth: int = 5,
361        current_depth: int = 0,
362    ) -> Any:
363        """
364        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
365
366        Args:
367            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
368            key_name (str): The key to search for in the response data.
369            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
370            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
371
372        Returns:
373            Any: The value associated with the specified key if found, otherwise None.
374
375        Raises:
376            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
377        """
378        if current_depth > max_depth:
379            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
380            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
381            raise ResponseKeysMaxRecurtionReached(
382                internal_message=message, message=message, failure_type=FailureType.config_error
383            )
384
385        if isinstance(response_data, dict):
386            # get from the root level
387            if key_name in response_data:
388                return response_data[key_name]
389
390            # get from the nested object
391            for _, value in response_data.items():
392                result = self._find_and_get_value_from_response(
393                    value, key_name, max_depth, current_depth + 1
394                )
395                if result is not None:
396                    return result
397
398        # get from the nested array object
399        elif isinstance(response_data, list):
400            for item in response_data:
401                result = self._find_and_get_value_from_response(
402                    item, key_name, max_depth, current_depth + 1
403                )
404                if result is not None:
405                    return result
406
407        return None
408
409    @property
410    def _message_repository(self) -> Optional[MessageRepository]:
411        """
412        The implementation can define a message_repository if it wants debugging logs for HTTP requests
413        """
414        return _NOOP_MESSAGE_REPOSITORY
415
416    def _log_response(self, response: requests.Response) -> None:
417        """
418        Logs the HTTP response using the message repository if it is available.
419
420        Args:
421            response (requests.Response): The HTTP response to log.
422        """
423        if self._message_repository:
424            self._message_repository.log_message(
425                Level.DEBUG,
426                lambda: format_http_message(
427                    response,
428                    "Refresh token",
429                    "Obtains access token",
430                    self._NO_STREAM_NAME,
431                    is_auxiliary=True,
432                    type="AUTH",
433                ),
434            )
435
436    # ----------------
437    # ABSTR METHODS
438    # ----------------
439
    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token; the return type allows None"""
443
    @abstractmethod
    def get_client_id_name(self) -> str:
        """The client id name to authenticate; used as the key for the client id in the refresh request body"""
447
    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate; sent in the refresh request body"""
451
    @abstractmethod
    def get_client_secret_name(self) -> str:
        """The client secret name to authenticate; used as the key for the client secret in the refresh request body"""
455
    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate; sent in the refresh request body"""
459
    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """
        The refresh token name to authenticate

        Used both as the key for the refresh token in the refresh request body and as the
        field searched for in the response by `_extract_refresh_token`.
        """
463
    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires; sent in the refresh request body"""
467
    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes; when non-empty it is sent under the "scopes" key of the refresh request body"""
471
    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token; `token_has_expired` compares it with the current time"""
475
    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date; `get_access_token` calls it after a refresh"""
479
    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response; nested objects and lists are searched for it too"""
483
    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the expires_in field name, i.e. the response field that `_extract_token_expiry_date` searches for"""
487
    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request

        `build_refresh_request_body` merges these fields into the payload, skipping any key
        the payload already contains.
        """
491
    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """
        Returns the request headers to set on the refresh request

        `build_refresh_request_headers` turns an empty (falsy) result into None.
        """
495
    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token; sent in the refresh request body"""
499
    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token; used as the key for the grant type in the refresh request body"""
503
    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token currently held by the authenticator"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token; `get_access_token` assigns the refreshed token through it"""

Abstract class for OAuth authenticators that implements the OAuth token refresh flow. The authenticator is designed to perform the refresh flow generically, without regard to how config fields are read or written; that behavior is delegated to the classes implementing the interface.

AbstractOauth2Authenticator( refresh_token_error_status_codes: Tuple[int, ...] = (), refresh_token_error_key: str = '', refresh_token_error_values: Tuple[str, ...] = ())
45    def __init__(
46        self,
47        refresh_token_error_status_codes: Tuple[int, ...] = (),
48        refresh_token_error_key: str = "",
49        refresh_token_error_values: Tuple[str, ...] = (),
50    ) -> None:
51        """
52        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
53        then http errors with such params will be wrapped in AirbyteTracedException.
54        """
55        self._refresh_token_error_status_codes = refresh_token_error_status_codes
56        self._refresh_token_error_key = refresh_token_error_key
57        self._refresh_token_error_values = refresh_token_error_values

If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, then http errors with such params will be wrapped in AirbyteTracedException.

token_expiry_is_time_of_expiration: bool
68    @property
69    def token_expiry_is_time_of_expiration(self) -> bool:
70        """
71        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
72        """
73
74        return False

Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.

token_expiry_date_format: Optional[str]
76    @property
77    def token_expiry_date_format(self) -> Optional[str]:
78        """
79        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
80        """
81
82        return None

Format of the datetime; set it if expires_in is returned as the expiration datetime instead of seconds until it expires

def get_auth_header(self) -> Mapping[str, Any]:
84    def get_auth_header(self) -> Mapping[str, Any]:
85        """HTTP header to set on the requests"""
86        token = self.access_token if self._is_access_token_flow else self.get_access_token()
87        return {"Authorization": f"Bearer {token}"}

HTTP header to set on the requests

def get_access_token(self) -> str:
89    def get_access_token(self) -> str:
90        """Returns the access token"""
91        if self.token_has_expired():
92            token, expires_in = self.refresh_access_token()
93            self.access_token = token
94            self.set_token_expiry_date(expires_in)
95
96        return self.access_token

Returns the access token

def token_has_expired(self) -> bool:
 98    def token_has_expired(self) -> bool:
 99        """Returns True if the token is expired"""
100        return ab_datetime_now() > self.get_token_expiry_date()

Returns True if the token is expired

def build_refresh_request_body(self) -> Mapping[str, Any]:
102    def build_refresh_request_body(self) -> Mapping[str, Any]:
103        """
104        Returns the request body to set on the refresh request
105
106        Override to define additional parameters
107        """
108        payload: MutableMapping[str, Any] = {
109            self.get_grant_type_name(): self.get_grant_type(),
110            self.get_client_id_name(): self.get_client_id(),
111            self.get_client_secret_name(): self.get_client_secret(),
112            self.get_refresh_token_name(): self.get_refresh_token(),
113        }
114
115        if self.get_scopes():
116            payload["scopes"] = self.get_scopes()
117
118        if self.get_refresh_request_body():
119            for key, val in self.get_refresh_request_body().items():
120                # We defer to existing oauth constructs over custom configured fields
121                if key not in payload:
122                    payload[key] = val
123
124        return payload

Returns the request body to set on the refresh request

Override to define additional parameters

def build_refresh_request_headers(self) -> Optional[Mapping[str, Any]]:
126    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
127        """
128        Returns the request headers to set on the refresh request
129
130        """
131        headers = self.get_refresh_request_headers()
132        return headers if headers else None

Returns the request headers to set on the refresh request

def refresh_access_token(self) -> Tuple[str, airbyte_cdk.utils.datetime_helpers.AirbyteDateTime]:
134    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
135        """
136        Returns the refresh token and its expiration datetime
137
138        :return: a tuple of (access_token, token_lifespan)
139        """
140        response_json = self._make_handled_request()
141        self._ensure_access_token_in_response(response_json)
142
143        return (
144            self._extract_access_token(response_json),
145            self._extract_token_expiry_date(response_json),
146        )

Returns the access token and its expiration datetime

Returns

a tuple of (access_token, token_expiry_date)

@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
440    @abstractmethod
441    def get_token_refresh_endpoint(self) -> Optional[str]:
442        """Returns the endpoint to refresh the access token"""

Returns the endpoint to refresh the access token

@abstractmethod
def get_client_id_name(self) -> str:
444    @abstractmethod
445    def get_client_id_name(self) -> str:
446        """The client id name to authenticate"""

The client id name to authenticate

@abstractmethod
def get_client_id(self) -> str:
448    @abstractmethod
449    def get_client_id(self) -> str:
450        """The client id to authenticate"""

The client id to authenticate

@abstractmethod
def get_client_secret_name(self) -> str:
452    @abstractmethod
453    def get_client_secret_name(self) -> str:
454        """The client secret name to authenticate"""

The client secret name to authenticate

@abstractmethod
def get_client_secret(self) -> str:
456    @abstractmethod
457    def get_client_secret(self) -> str:
458        """The client secret to authenticate"""

The client secret to authenticate

@abstractmethod
def get_refresh_token_name(self) -> str:
460    @abstractmethod
461    def get_refresh_token_name(self) -> str:
462        """The refresh token name to authenticate"""

The refresh token name to authenticate

@abstractmethod
def get_refresh_token(self) -> Optional[str]:
464    @abstractmethod
465    def get_refresh_token(self) -> Optional[str]:
466        """The token used to refresh the access token when it expires"""

The token used to refresh the access token when it expires

@abstractmethod
def get_scopes(self) -> List[str]:
468    @abstractmethod
469    def get_scopes(self) -> List[str]:
470        """List of requested scopes"""

List of requested scopes

@abstractmethod
def get_token_expiry_date(self) -> airbyte_cdk.utils.datetime_helpers.AirbyteDateTime:
472    @abstractmethod
473    def get_token_expiry_date(self) -> AirbyteDateTime:
474        """Expiration date of the access token"""

Expiration date of the access token

@abstractmethod
def set_token_expiry_date(self, value: airbyte_cdk.utils.datetime_helpers.AirbyteDateTime) -> None:
476    @abstractmethod
477    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
478        """Setter for access token expiration date"""

Setter for access token expiration date

@abstractmethod
def get_access_token_name(self) -> str:
480    @abstractmethod
481    def get_access_token_name(self) -> str:
482        """Field to extract access token from in the response"""

Field to extract access token from in the response

@abstractmethod
def get_expires_in_name(self) -> str:
484    @abstractmethod
485    def get_expires_in_name(self) -> str:
486        """Returns the expires_in field name"""

Returns the expires_in field name

@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
488    @abstractmethod
489    def get_refresh_request_body(self) -> Mapping[str, Any]:
490        """Returns the request body to set on the refresh request"""

Returns the request body to set on the refresh request

@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
492    @abstractmethod
493    def get_refresh_request_headers(self) -> Mapping[str, Any]:
494        """Returns the request headers to set on the refresh request"""

Returns the request headers to set on the refresh request

@abstractmethod
def get_grant_type(self) -> str:
496    @abstractmethod
497    def get_grant_type(self) -> str:
498        """Returns grant_type specified for requesting access_token"""

Returns grant_type specified for requesting access_token

@abstractmethod
def get_grant_type_name(self) -> str:
500    @abstractmethod
501    def get_grant_type_name(self) -> str:
502        """Returns grant_type specified name for requesting access_token"""

Returns grant_type specified name for requesting access_token

access_token: str
504    @property
505    @abstractmethod
506    def access_token(self) -> str:
507        """Returns the access token"""

Returns the access token