airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth

  1#
  2# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
  3#
  4
  5import logging
  6from abc import abstractmethod
  7from datetime import timedelta
  8from json import JSONDecodeError
  9from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union
 10
 11import backoff
 12import requests
 13from requests.auth import AuthBase
 14
 15from airbyte_cdk.models import FailureType, Level
 16from airbyte_cdk.sources.http_logger import format_http_message
 17from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
 18from airbyte_cdk.utils import AirbyteTracedException
 19from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
 20from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse
 21
 22from ..exceptions import DefaultBackoffException
 23
# Module logger; used by the backoff handler on `_make_handled_request` to report retries.
logger = logging.getLogger("airbyte")
# Shared default returned by `AbstractOauth2Authenticator._message_repository`.
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()
 26
 27
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is exceeded while searching a nested response
    for a target key (e.g. the access token returned by `_make_handled_request`).
    """
    # NOTE: the misspelling in the class name is part of the public API and is kept for compatibility.
 33
 34
 35class AbstractOauth2Authenticator(AuthBase):
 36    """
 37    Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
 38    is designed to generically perform the refresh flow without regard to how config fields are get/set by
 39    delegating that behavior to the classes implementing the interface.
 40    """
 41
 42    _NO_STREAM_NAME = None
 43
 44    def __init__(
 45        self,
 46        refresh_token_error_status_codes: Tuple[int, ...] = (),
 47        refresh_token_error_key: str = "",
 48        refresh_token_error_values: Tuple[str, ...] = (),
 49    ) -> None:
 50        """
 51        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 52        then http errors with such params will be wrapped in AirbyteTracedException.
 53        """
 54        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 55        self._refresh_token_error_key = refresh_token_error_key
 56        self._refresh_token_error_values = refresh_token_error_values
 57
 58    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 59        """Attach the HTTP headers required to authenticate on the HTTP request"""
 60        request.headers.update(self.get_auth_header())
 61        return request
 62
 63    @property
 64    def _is_access_token_flow(self) -> bool:
 65        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 66
 67    @property
 68    def token_expiry_is_time_of_expiration(self) -> bool:
 69        """
 70        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
 71        """
 72
 73        return False
 74
 75    @property
 76    def token_expiry_date_format(self) -> Optional[str]:
 77        """
 78        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
 79        """
 80
 81        return None
 82
 83    def get_auth_header(self) -> Mapping[str, Any]:
 84        """HTTP header to set on the requests"""
 85        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 86        return {"Authorization": f"Bearer {token}"}
 87
 88    def get_access_token(self) -> str:
 89        """Returns the access token"""
 90        if self.token_has_expired():
 91            token, expires_in = self.refresh_access_token()
 92            self.access_token = token
 93            self.set_token_expiry_date(expires_in)
 94
 95        return self.access_token
 96
 97    def token_has_expired(self) -> bool:
 98        """Returns True if the token is expired"""
 99        return ab_datetime_now() > self.get_token_expiry_date()
100
101    def build_refresh_request_body(self) -> Mapping[str, Any]:
102        """
103        Returns the request body to set on the refresh request
104
105        Override to define additional parameters
106        """
107        payload: MutableMapping[str, Any] = {
108            self.get_grant_type_name(): self.get_grant_type(),
109            self.get_client_id_name(): self.get_client_id(),
110            self.get_client_secret_name(): self.get_client_secret(),
111            self.get_refresh_token_name(): self.get_refresh_token(),
112        }
113
114        if self.get_scopes():
115            payload["scopes"] = self.get_scopes()
116
117        if self.get_refresh_request_body():
118            for key, val in self.get_refresh_request_body().items():
119                # We defer to existing oauth constructs over custom configured fields
120                if key not in payload:
121                    payload[key] = val
122
123        return payload
124
125    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
126        """
127        Returns the request headers to set on the refresh request
128
129        """
130        headers = self.get_refresh_request_headers()
131        return headers if headers else None
132
133    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
134        """
135        Returns the refresh token and its expiration datetime
136
137        :return: a tuple of (access_token, token_lifespan)
138        """
139        response_json = self._make_handled_request()
140        self._ensure_access_token_in_response(response_json)
141
142        return (
143            self._extract_access_token(response_json),
144            self._extract_token_expiry_date(response_json),
145        )
146
147    # ----------------
148    # PRIVATE METHODS
149    # ----------------
150
151    def _default_token_expiry_date(self) -> AirbyteDateTime:
152        """
153        Returns the default token expiry date
154        """
155        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
156        default_token_expiry_duration_hours = 1  # 1 hour
157        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
158
159    def _wrap_refresh_token_exception(
160        self, exception: requests.exceptions.RequestException
161    ) -> bool:
162        """
163        Wraps and handles exceptions that occur during the refresh token process.
164
165        This method checks if the provided exception is related to a refresh token error
166        by examining the response status code and specific error content.
167
168        Args:
169            exception (requests.exceptions.RequestException): The exception raised during the request.
170
171        Returns:
172            bool: True if the exception is related to a refresh token error, False otherwise.
173        """
174        try:
175            if exception.response is not None:
176                exception_content = exception.response.json()
177            else:
178                return False
179        except JSONDecodeError:
180            return False
181        return (
182            exception.response.status_code in self._refresh_token_error_status_codes
183            and exception_content.get(self._refresh_token_error_key)
184            in self._refresh_token_error_values
185        )
186
187    @backoff.on_exception(
188        backoff.expo,
189        DefaultBackoffException,
190        on_backoff=lambda details: logger.info(
191            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
192        ),
193        max_time=300,
194    )
195    def _make_handled_request(self) -> Any:
196        """
197        Makes a handled HTTP request to refresh an OAuth token.
198
199        This method sends a POST request to the token refresh endpoint with the necessary
200        headers and body to obtain a new access token. It handles various exceptions that
201        may occur during the request and logs the response for troubleshooting purposes.
202
203        Returns:
204            Mapping[str, Any]: The JSON response from the token refresh endpoint.
205
206        Raises:
207            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
208                                     or any 5xx server error.
209            AirbyteTracedException: If the refresh token is invalid or expired, prompting
210                                    re-authentication.
211            Exception: For any other exceptions that occur during the request.
212        """
213        try:
214            response = requests.request(
215                method="POST",
216                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
217                data=self.build_refresh_request_body(),
218                headers=self.build_refresh_request_headers(),
219            )
220
221            if not response.ok:
222                # log the response even if the request failed for troubleshooting purposes
223                self._log_response(response)
224                response.raise_for_status()
225
226            response_json = response.json()
227
228            try:
229                # extract the access token and add to secrets to avoid logging the raw value
230                access_key = self._extract_access_token(response_json)
231                if access_key:
232                    add_to_secrets(access_key)
233            except ResponseKeysMaxRecurtionReached as e:
234                # could not find the access token in the response, so do nothing
235                pass
236
237            self._log_response(response)
238
239            return response_json
240        except requests.exceptions.RequestException as e:
241            if e.response is not None:
242                if e.response.status_code == 429 or e.response.status_code >= 500:
243                    raise DefaultBackoffException(request=e.response.request, response=e.response)
244            if self._wrap_refresh_token_exception(e):
245                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
246                raise AirbyteTracedException(
247                    internal_message=message, message=message, failure_type=FailureType.config_error
248                )
249            raise
250        except Exception as e:
251            raise Exception(f"Error while refreshing access token: {e}") from e
252
253    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
254        """
255        Ensures that the access token is present in the response data.
256
257        This method attempts to extract the access token from the provided response data.
258        If the access token is not found, it raises an exception indicating that the token
259        refresh API response was missing the access token.
260
261        Args:
262            response_data (Mapping[str, Any]): The response data from which to extract the access token.
263
264        Raises:
265            Exception: If the access token is not found in the response data.
266            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
267        """
268        try:
269            access_key = self._extract_access_token(response_data)
270            if not access_key:
271                raise Exception(
272                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
273                )
274        except ResponseKeysMaxRecurtionReached as e:
275            raise e
276
277    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
278        """
279        Parse a string or integer token expiration date into a datetime object
280
281        :return: expiration datetime
282        """
283        if self.token_expiry_is_time_of_expiration:
284            if not self.token_expiry_date_format:
285                raise ValueError(
286                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
287                )
288            try:
289                return ab_datetime_parse(str(value))
290            except ValueError as e:
291                raise ValueError(f"Invalid token expiry date format: {e}")
292        else:
293            try:
294                # Only accept numeric values (as int/float/string) when no format specified
295                seconds = int(float(str(value)))
296                return ab_datetime_now() + timedelta(seconds=seconds)
297            except (ValueError, TypeError):
298                raise ValueError(
299                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
300                )
301
302    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
303        """
304        Extracts the access token from the given response data.
305
306        Args:
307            response_data (Mapping[str, Any]): The response data from which to extract the access token.
308
309        Returns:
310            str: The extracted access token.
311        """
312        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
313
314    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
315        """
316        Extracts the refresh token from the given response data.
317
318        Args:
319            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
320
321        Returns:
322            str: The extracted refresh token.
323        """
324        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
325
326    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
327        """
328        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
329
330        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
331
332        Args:
333            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
334
335        Returns:
336            The extracted token_expiry_date or None if not found.
337        """
338        expires_in = self._find_and_get_value_from_response(
339            response_data, self.get_expires_in_name()
340        )
341        if expires_in is not None:
342            return self._parse_token_expiration_date(expires_in)
343
344        # expires_in is None
345        existing_expiry_date = self.get_token_expiry_date()
346        if existing_expiry_date and not self.token_has_expired():
347            return existing_expiry_date
348
349        return self._default_token_expiry_date()
350
351    def _find_and_get_value_from_response(
352        self,
353        response_data: Mapping[str, Any],
354        key_name: str,
355        max_depth: int = 5,
356        current_depth: int = 0,
357    ) -> Any:
358        """
359        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
360
361        Args:
362            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
363            key_name (str): The key to search for in the response data.
364            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
365            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
366
367        Returns:
368            Any: The value associated with the specified key if found, otherwise None.
369
370        Raises:
371            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
372        """
373        if current_depth > max_depth:
374            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
375            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
376            raise ResponseKeysMaxRecurtionReached(
377                internal_message=message, message=message, failure_type=FailureType.config_error
378            )
379
380        if isinstance(response_data, dict):
381            # get from the root level
382            if key_name in response_data:
383                return response_data[key_name]
384
385            # get from the nested object
386            for _, value in response_data.items():
387                result = self._find_and_get_value_from_response(
388                    value, key_name, max_depth, current_depth + 1
389                )
390                if result is not None:
391                    return result
392
393        # get from the nested array object
394        elif isinstance(response_data, list):
395            for item in response_data:
396                result = self._find_and_get_value_from_response(
397                    item, key_name, max_depth, current_depth + 1
398                )
399                if result is not None:
400                    return result
401
402        return None
403
404    @property
405    def _message_repository(self) -> Optional[MessageRepository]:
406        """
407        The implementation can define a message_repository if it wants debugging logs for HTTP requests
408        """
409        return _NOOP_MESSAGE_REPOSITORY
410
411    def _log_response(self, response: requests.Response) -> None:
412        """
413        Logs the HTTP response using the message repository if it is available.
414
415        Args:
416            response (requests.Response): The HTTP response to log.
417        """
418        if self._message_repository:
419            self._message_repository.log_message(
420                Level.DEBUG,
421                lambda: format_http_message(
422                    response,
423                    "Refresh token",
424                    "Obtains access token",
425                    self._NO_STREAM_NAME,
426                    is_auxiliary=True,
427                    type="AUTH",
428                ),
429            )
430
431    # ----------------
432    # ABSTR METHODS
433    # ----------------
434
435    @abstractmethod
436    def get_token_refresh_endpoint(self) -> Optional[str]:
437        """Returns the endpoint to refresh the access token"""
438
439    @abstractmethod
440    def get_client_id_name(self) -> str:
441        """The client id name to authenticate"""
442
443    @abstractmethod
444    def get_client_id(self) -> str:
445        """The client id to authenticate"""
446
447    @abstractmethod
448    def get_client_secret_name(self) -> str:
449        """The client secret name to authenticate"""
450
451    @abstractmethod
452    def get_client_secret(self) -> str:
453        """The client secret to authenticate"""
454
455    @abstractmethod
456    def get_refresh_token_name(self) -> str:
457        """The refresh token name to authenticate"""
458
459    @abstractmethod
460    def get_refresh_token(self) -> Optional[str]:
461        """The token used to refresh the access token when it expires"""
462
463    @abstractmethod
464    def get_scopes(self) -> List[str]:
465        """List of requested scopes"""
466
467    @abstractmethod
468    def get_token_expiry_date(self) -> AirbyteDateTime:
469        """Expiration date of the access token"""
470
471    @abstractmethod
472    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
473        """Setter for access token expiration date"""
474
475    @abstractmethod
476    def get_access_token_name(self) -> str:
477        """Field to extract access token from in the response"""
478
479    @abstractmethod
480    def get_expires_in_name(self) -> str:
481        """Returns the expires_in field name"""
482
483    @abstractmethod
484    def get_refresh_request_body(self) -> Mapping[str, Any]:
485        """Returns the request body to set on the refresh request"""
486
487    @abstractmethod
488    def get_refresh_request_headers(self) -> Mapping[str, Any]:
489        """Returns the request headers to set on the refresh request"""
490
491    @abstractmethod
492    def get_grant_type(self) -> str:
493        """Returns grant_type specified for requesting access_token"""
494
495    @abstractmethod
496    def get_grant_type_name(self) -> str:
497        """Returns grant_type specified name for requesting access_token"""
498
499    @property
500    @abstractmethod
501    def access_token(self) -> str:
502        """Returns the access token"""
503
504    @access_token.setter
505    @abstractmethod
506    def access_token(self, value: str) -> str:
507        """Setter for the access token"""
logger = <Logger airbyte (INFO)>
class ResponseKeysMaxRecurtionReached(airbyte_cdk.utils.traced_exception.AirbyteTracedException):
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is exceeded while searching a nested response
    for a target key (e.g. the access token returned by `_make_handled_request`).
    """
    # NOTE: the misspelling in the class name is part of the public API and is kept for compatibility.

Raised when the maximum recursion depth is reached while searching for the target key in a response, e.g. during _make_handled_request.

class AbstractOauth2Authenticator(requests.auth.AuthBase):
 36class AbstractOauth2Authenticator(AuthBase):
 37    """
 38    Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
 39    is designed to generically perform the refresh flow without regard to how config fields are get/set by
 40    delegating that behavior to the classes implementing the interface.
 41    """
 42
    # Stream name passed to `format_http_message` when logging the refresh request.
    _NO_STREAM_NAME = None
 44
 45    def __init__(
 46        self,
 47        refresh_token_error_status_codes: Tuple[int, ...] = (),
 48        refresh_token_error_key: str = "",
 49        refresh_token_error_values: Tuple[str, ...] = (),
 50    ) -> None:
 51        """
 52        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 53        then http errors with such params will be wrapped in AirbyteTracedException.
 54        """
 55        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 56        self._refresh_token_error_key = refresh_token_error_key
 57        self._refresh_token_error_values = refresh_token_error_values
 58
 59    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 60        """Attach the HTTP headers required to authenticate on the HTTP request"""
 61        request.headers.update(self.get_auth_header())
 62        return request
 63
 64    @property
 65    def _is_access_token_flow(self) -> bool:
 66        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 67
 68    @property
 69    def token_expiry_is_time_of_expiration(self) -> bool:
 70        """
 71        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
 72        """
 73
 74        return False
 75
 76    @property
 77    def token_expiry_date_format(self) -> Optional[str]:
 78        """
 79        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
 80        """
 81
 82        return None
 83
 84    def get_auth_header(self) -> Mapping[str, Any]:
 85        """HTTP header to set on the requests"""
 86        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 87        return {"Authorization": f"Bearer {token}"}
 88
 89    def get_access_token(self) -> str:
 90        """Returns the access token"""
 91        if self.token_has_expired():
 92            token, expires_in = self.refresh_access_token()
 93            self.access_token = token
 94            self.set_token_expiry_date(expires_in)
 95
 96        return self.access_token
 97
 98    def token_has_expired(self) -> bool:
 99        """Returns True if the token is expired"""
100        return ab_datetime_now() > self.get_token_expiry_date()
101
102    def build_refresh_request_body(self) -> Mapping[str, Any]:
103        """
104        Returns the request body to set on the refresh request
105
106        Override to define additional parameters
107        """
108        payload: MutableMapping[str, Any] = {
109            self.get_grant_type_name(): self.get_grant_type(),
110            self.get_client_id_name(): self.get_client_id(),
111            self.get_client_secret_name(): self.get_client_secret(),
112            self.get_refresh_token_name(): self.get_refresh_token(),
113        }
114
115        if self.get_scopes():
116            payload["scopes"] = self.get_scopes()
117
118        if self.get_refresh_request_body():
119            for key, val in self.get_refresh_request_body().items():
120                # We defer to existing oauth constructs over custom configured fields
121                if key not in payload:
122                    payload[key] = val
123
124        return payload
125
126    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
127        """
128        Returns the request headers to set on the refresh request
129
130        """
131        headers = self.get_refresh_request_headers()
132        return headers if headers else None
133
134    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
135        """
136        Returns the refresh token and its expiration datetime
137
138        :return: a tuple of (access_token, token_lifespan)
139        """
140        response_json = self._make_handled_request()
141        self._ensure_access_token_in_response(response_json)
142
143        return (
144            self._extract_access_token(response_json),
145            self._extract_token_expiry_date(response_json),
146        )
147
148    # ----------------
149    # PRIVATE METHODS
150    # ----------------
151
152    def _default_token_expiry_date(self) -> AirbyteDateTime:
153        """
154        Returns the default token expiry date
155        """
156        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
157        default_token_expiry_duration_hours = 1  # 1 hour
158        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
159
160    def _wrap_refresh_token_exception(
161        self, exception: requests.exceptions.RequestException
162    ) -> bool:
163        """
164        Wraps and handles exceptions that occur during the refresh token process.
165
166        This method checks if the provided exception is related to a refresh token error
167        by examining the response status code and specific error content.
168
169        Args:
170            exception (requests.exceptions.RequestException): The exception raised during the request.
171
172        Returns:
173            bool: True if the exception is related to a refresh token error, False otherwise.
174        """
175        try:
176            if exception.response is not None:
177                exception_content = exception.response.json()
178            else:
179                return False
180        except JSONDecodeError:
181            return False
182        return (
183            exception.response.status_code in self._refresh_token_error_status_codes
184            and exception_content.get(self._refresh_token_error_key)
185            in self._refresh_token_error_values
186        )
187
188    @backoff.on_exception(
189        backoff.expo,
190        DefaultBackoffException,
191        on_backoff=lambda details: logger.info(
192            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
193        ),
194        max_time=300,
195    )
196    def _make_handled_request(self) -> Any:
197        """
198        Makes a handled HTTP request to refresh an OAuth token.
199
200        This method sends a POST request to the token refresh endpoint with the necessary
201        headers and body to obtain a new access token. It handles various exceptions that
202        may occur during the request and logs the response for troubleshooting purposes.
203
204        Returns:
205            Mapping[str, Any]: The JSON response from the token refresh endpoint.
206
207        Raises:
208            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
209                                     or any 5xx server error.
210            AirbyteTracedException: If the refresh token is invalid or expired, prompting
211                                    re-authentication.
212            Exception: For any other exceptions that occur during the request.
213        """
214        try:
215            response = requests.request(
216                method="POST",
217                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
218                data=self.build_refresh_request_body(),
219                headers=self.build_refresh_request_headers(),
220            )
221
222            if not response.ok:
223                # log the response even if the request failed for troubleshooting purposes
224                self._log_response(response)
225                response.raise_for_status()
226
227            response_json = response.json()
228
229            try:
230                # extract the access token and add to secrets to avoid logging the raw value
231                access_key = self._extract_access_token(response_json)
232                if access_key:
233                    add_to_secrets(access_key)
234            except ResponseKeysMaxRecurtionReached as e:
235                # could not find the access token in the response, so do nothing
236                pass
237
238            self._log_response(response)
239
240            return response_json
241        except requests.exceptions.RequestException as e:
242            if e.response is not None:
243                if e.response.status_code == 429 or e.response.status_code >= 500:
244                    raise DefaultBackoffException(request=e.response.request, response=e.response)
245            if self._wrap_refresh_token_exception(e):
246                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
247                raise AirbyteTracedException(
248                    internal_message=message, message=message, failure_type=FailureType.config_error
249                )
250            raise
251        except Exception as e:
252            raise Exception(f"Error while refreshing access token: {e}") from e
253
254    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
255        """
256        Ensures that the access token is present in the response data.
257
258        This method attempts to extract the access token from the provided response data.
259        If the access token is not found, it raises an exception indicating that the token
260        refresh API response was missing the access token.
261
262        Args:
263            response_data (Mapping[str, Any]): The response data from which to extract the access token.
264
265        Raises:
266            Exception: If the access token is not found in the response data.
267            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
268        """
269        try:
270            access_key = self._extract_access_token(response_data)
271            if not access_key:
272                raise Exception(
273                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
274                )
275        except ResponseKeysMaxRecurtionReached as e:
276            raise e
277
278    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
279        """
280        Parse a string or integer token expiration date into a datetime object
281
282        :return: expiration datetime
283        """
284        if self.token_expiry_is_time_of_expiration:
285            if not self.token_expiry_date_format:
286                raise ValueError(
287                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
288                )
289            try:
290                return ab_datetime_parse(str(value))
291            except ValueError as e:
292                raise ValueError(f"Invalid token expiry date format: {e}")
293        else:
294            try:
295                # Only accept numeric values (as int/float/string) when no format specified
296                seconds = int(float(str(value)))
297                return ab_datetime_now() + timedelta(seconds=seconds)
298            except (ValueError, TypeError):
299                raise ValueError(
300                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
301                )
302
303    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
304        """
305        Extracts the access token from the given response data.
306
307        Args:
308            response_data (Mapping[str, Any]): The response data from which to extract the access token.
309
310        Returns:
311            str: The extracted access token.
312        """
313        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
314
315    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
316        """
317        Extracts the refresh token from the given response data.
318
319        Args:
320            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
321
322        Returns:
323            str: The extracted refresh token.
324        """
325        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
326
327    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
328        """
329        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
330
331        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
332
333        Args:
334            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
335
336        Returns:
337            The extracted token_expiry_date or None if not found.
338        """
339        expires_in = self._find_and_get_value_from_response(
340            response_data, self.get_expires_in_name()
341        )
342        if expires_in is not None:
343            return self._parse_token_expiration_date(expires_in)
344
345        # expires_in is None
346        existing_expiry_date = self.get_token_expiry_date()
347        if existing_expiry_date and not self.token_has_expired():
348            return existing_expiry_date
349
350        return self._default_token_expiry_date()
351
352    def _find_and_get_value_from_response(
353        self,
354        response_data: Mapping[str, Any],
355        key_name: str,
356        max_depth: int = 5,
357        current_depth: int = 0,
358    ) -> Any:
359        """
360        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
361
362        Args:
363            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
364            key_name (str): The key to search for in the response data.
365            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
366            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
367
368        Returns:
369            Any: The value associated with the specified key if found, otherwise None.
370
371        Raises:
372            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
373        """
374        if current_depth > max_depth:
375            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
376            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
377            raise ResponseKeysMaxRecurtionReached(
378                internal_message=message, message=message, failure_type=FailureType.config_error
379            )
380
381        if isinstance(response_data, dict):
382            # get from the root level
383            if key_name in response_data:
384                return response_data[key_name]
385
386            # get from the nested object
387            for _, value in response_data.items():
388                result = self._find_and_get_value_from_response(
389                    value, key_name, max_depth, current_depth + 1
390                )
391                if result is not None:
392                    return result
393
394        # get from the nested array object
395        elif isinstance(response_data, list):
396            for item in response_data:
397                result = self._find_and_get_value_from_response(
398                    item, key_name, max_depth, current_depth + 1
399                )
400                if result is not None:
401                    return result
402
403        return None
404
405    @property
406    def _message_repository(self) -> Optional[MessageRepository]:
407        """
408        The implementation can define a message_repository if it wants debugging logs for HTTP requests
409        """
410        return _NOOP_MESSAGE_REPOSITORY
411
412    def _log_response(self, response: requests.Response) -> None:
413        """
414        Logs the HTTP response using the message repository if it is available.
415
416        Args:
417            response (requests.Response): The HTTP response to log.
418        """
419        if self._message_repository:
420            self._message_repository.log_message(
421                Level.DEBUG,
422                lambda: format_http_message(
423                    response,
424                    "Refresh token",
425                    "Obtains access token",
426                    self._NO_STREAM_NAME,
427                    is_auxiliary=True,
428                    type="AUTH",
429                ),
430            )
431
432    # ----------------
    # ABSTRACT METHODS
434    # ----------------
435
436    @abstractmethod
437    def get_token_refresh_endpoint(self) -> Optional[str]:
438        """Returns the endpoint to refresh the access token"""
439
440    @abstractmethod
441    def get_client_id_name(self) -> str:
442        """The client id name to authenticate"""
443
444    @abstractmethod
445    def get_client_id(self) -> str:
446        """The client id to authenticate"""
447
448    @abstractmethod
449    def get_client_secret_name(self) -> str:
450        """The client secret name to authenticate"""
451
452    @abstractmethod
453    def get_client_secret(self) -> str:
454        """The client secret to authenticate"""
455
456    @abstractmethod
457    def get_refresh_token_name(self) -> str:
458        """The refresh token name to authenticate"""
459
460    @abstractmethod
461    def get_refresh_token(self) -> Optional[str]:
462        """The token used to refresh the access token when it expires"""
463
464    @abstractmethod
465    def get_scopes(self) -> List[str]:
466        """List of requested scopes"""
467
468    @abstractmethod
469    def get_token_expiry_date(self) -> AirbyteDateTime:
470        """Expiration date of the access token"""
471
472    @abstractmethod
473    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
474        """Setter for access token expiration date"""
475
476    @abstractmethod
477    def get_access_token_name(self) -> str:
478        """Field to extract access token from in the response"""
479
480    @abstractmethod
481    def get_expires_in_name(self) -> str:
482        """Returns the expires_in field name"""
483
484    @abstractmethod
485    def get_refresh_request_body(self) -> Mapping[str, Any]:
486        """Returns the request body to set on the refresh request"""
487
488    @abstractmethod
489    def get_refresh_request_headers(self) -> Mapping[str, Any]:
490        """Returns the request headers to set on the refresh request"""
491
492    @abstractmethod
493    def get_grant_type(self) -> str:
494        """Returns grant_type specified for requesting access_token"""
495
496    @abstractmethod
497    def get_grant_type_name(self) -> str:
498        """Returns grant_type specified name for requesting access_token"""
499
500    @property
501    @abstractmethod
502    def access_token(self) -> str:
503        """Returns the access token"""
504
505    @access_token.setter
506    @abstractmethod
507    def access_token(self, value: str) -> str:
508        """Setter for the access token"""

Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator performs the refresh flow generically, without regard to how config fields are read or written, by delegating that behavior to the classes implementing the interface.

AbstractOauth2Authenticator( refresh_token_error_status_codes: Tuple[int, ...] = (), refresh_token_error_key: str = '', refresh_token_error_values: Tuple[str, ...] = ())
45    def __init__(
46        self,
47        refresh_token_error_status_codes: Tuple[int, ...] = (),
48        refresh_token_error_key: str = "",
49        refresh_token_error_values: Tuple[str, ...] = (),
50    ) -> None:
51        """
52        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
53        then http errors with such params will be wrapped in AirbyteTracedException.
54        """
55        self._refresh_token_error_status_codes = refresh_token_error_status_codes
56        self._refresh_token_error_key = refresh_token_error_key
57        self._refresh_token_error_values = refresh_token_error_values

If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, then http errors with such params will be wrapped in AirbyteTracedException.

token_expiry_is_time_of_expiration: bool
68    @property
69    def token_expiry_is_time_of_expiration(self) -> bool:
70        """
71        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
72        """
73
74        return False

Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.

token_expiry_date_format: Optional[str]
76    @property
77    def token_expiry_date_format(self) -> Optional[str]:
78        """
79        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
80        """
81
82        return None

Format of the datetime; set it if expires_in is returned as the expiration datetime instead of the number of seconds until the token expires.

def get_auth_header(self) -> Mapping[str, Any]:
84    def get_auth_header(self) -> Mapping[str, Any]:
85        """HTTP header to set on the requests"""
86        token = self.access_token if self._is_access_token_flow else self.get_access_token()
87        return {"Authorization": f"Bearer {token}"}

HTTP header to set on the requests

def get_access_token(self) -> str:
89    def get_access_token(self) -> str:
90        """Returns the access token"""
91        if self.token_has_expired():
92            token, expires_in = self.refresh_access_token()
93            self.access_token = token
94            self.set_token_expiry_date(expires_in)
95
96        return self.access_token

Returns the access token

def token_has_expired(self) -> bool:
 98    def token_has_expired(self) -> bool:
 99        """Returns True if the token is expired"""
100        return ab_datetime_now() > self.get_token_expiry_date()

Returns True if the token is expired

def build_refresh_request_body(self) -> Mapping[str, Any]:
102    def build_refresh_request_body(self) -> Mapping[str, Any]:
103        """
104        Returns the request body to set on the refresh request
105
106        Override to define additional parameters
107        """
108        payload: MutableMapping[str, Any] = {
109            self.get_grant_type_name(): self.get_grant_type(),
110            self.get_client_id_name(): self.get_client_id(),
111            self.get_client_secret_name(): self.get_client_secret(),
112            self.get_refresh_token_name(): self.get_refresh_token(),
113        }
114
115        if self.get_scopes():
116            payload["scopes"] = self.get_scopes()
117
118        if self.get_refresh_request_body():
119            for key, val in self.get_refresh_request_body().items():
120                # We defer to existing oauth constructs over custom configured fields
121                if key not in payload:
122                    payload[key] = val
123
124        return payload

Returns the request body to set on the refresh request

Override to define additional parameters

def build_refresh_request_headers(self) -> Optional[Mapping[str, Any]]:
126    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
127        """
128        Returns the request headers to set on the refresh request
129
130        """
131        headers = self.get_refresh_request_headers()
132        return headers if headers else None

Returns the request headers to set on the refresh request

def refresh_access_token(self) -> Tuple[str, airbyte_cdk.utils.datetime_helpers.AirbyteDateTime]:
134    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
135        """
136        Returns the refresh token and its expiration datetime
137
138        :return: a tuple of (access_token, token_lifespan)
139        """
140        response_json = self._make_handled_request()
141        self._ensure_access_token_in_response(response_json)
142
143        return (
144            self._extract_access_token(response_json),
145            self._extract_token_expiry_date(response_json),
146        )

Returns the access token and its expiration datetime

Returns

a tuple of (access_token, token_expiry_date), where token_expiry_date is the datetime at which the access token expires

@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
436    @abstractmethod
437    def get_token_refresh_endpoint(self) -> Optional[str]:
438        """Returns the endpoint to refresh the access token"""

Returns the endpoint to refresh the access token

@abstractmethod
def get_client_id_name(self) -> str:
440    @abstractmethod
441    def get_client_id_name(self) -> str:
442        """The client id name to authenticate"""

The client id name to authenticate

@abstractmethod
def get_client_id(self) -> str:
444    @abstractmethod
445    def get_client_id(self) -> str:
446        """The client id to authenticate"""

The client id to authenticate

@abstractmethod
def get_client_secret_name(self) -> str:
448    @abstractmethod
449    def get_client_secret_name(self) -> str:
450        """The client secret name to authenticate"""

The client secret name to authenticate

@abstractmethod
def get_client_secret(self) -> str:
452    @abstractmethod
453    def get_client_secret(self) -> str:
454        """The client secret to authenticate"""

The client secret to authenticate

@abstractmethod
def get_refresh_token_name(self) -> str:
456    @abstractmethod
457    def get_refresh_token_name(self) -> str:
458        """The refresh token name to authenticate"""

The refresh token name to authenticate

@abstractmethod
def get_refresh_token(self) -> Optional[str]:
460    @abstractmethod
461    def get_refresh_token(self) -> Optional[str]:
462        """The token used to refresh the access token when it expires"""

The token used to refresh the access token when it expires

@abstractmethod
def get_scopes(self) -> List[str]:
464    @abstractmethod
465    def get_scopes(self) -> List[str]:
466        """List of requested scopes"""

List of requested scopes

@abstractmethod
def get_token_expiry_date(self) -> airbyte_cdk.utils.datetime_helpers.AirbyteDateTime:
468    @abstractmethod
469    def get_token_expiry_date(self) -> AirbyteDateTime:
470        """Expiration date of the access token"""

Expiration date of the access token

@abstractmethod
def set_token_expiry_date(self, value: airbyte_cdk.utils.datetime_helpers.AirbyteDateTime) -> None:
472    @abstractmethod
473    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
474        """Setter for access token expiration date"""

Setter for access token expiration date

@abstractmethod
def get_access_token_name(self) -> str:
476    @abstractmethod
477    def get_access_token_name(self) -> str:
478        """Field to extract access token from in the response"""

Field to extract access token from in the response

@abstractmethod
def get_expires_in_name(self) -> str:
480    @abstractmethod
481    def get_expires_in_name(self) -> str:
482        """Returns the expires_in field name"""

Returns the expires_in field name

@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
484    @abstractmethod
485    def get_refresh_request_body(self) -> Mapping[str, Any]:
486        """Returns the request body to set on the refresh request"""

Returns the request body to set on the refresh request

@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
488    @abstractmethod
489    def get_refresh_request_headers(self) -> Mapping[str, Any]:
490        """Returns the request headers to set on the refresh request"""

Returns the request headers to set on the refresh request

@abstractmethod
def get_grant_type(self) -> str:
492    @abstractmethod
493    def get_grant_type(self) -> str:
494        """Returns grant_type specified for requesting access_token"""

Returns grant_type specified for requesting access_token

@abstractmethod
def get_grant_type_name(self) -> str:
496    @abstractmethod
497    def get_grant_type_name(self) -> str:
498        """Returns grant_type specified name for requesting access_token"""

Returns grant_type specified name for requesting access_token

access_token: str
500    @property
501    @abstractmethod
502    def access_token(self) -> str:
503        """Returns the access token"""

Returns the access token