airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth

  1#
  2# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
  3#
  4
  5import logging
  6from abc import abstractmethod
  7from datetime import timedelta
  8from json import JSONDecodeError
  9from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union
 10
 11import backoff
 12import requests
 13from requests.auth import AuthBase
 14
 15from airbyte_cdk.models import FailureType, Level
 16from airbyte_cdk.sources.http_logger import format_http_message
 17from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
 18from airbyte_cdk.utils import AirbyteTracedException
 19from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
 20from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse
 21
 22from ..exceptions import DefaultBackoffException
 23
 24logger = logging.getLogger("airbyte")
 25_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()
 26
 27
 28class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
 29    """
 30    Raised when the max level of recursion is reached, when trying to
 31    find-and-get the target key, during the `_make_handled_request`
 32    """
 33
 34
 35class AbstractOauth2Authenticator(AuthBase):
 36    """
 37    Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
 38    is designed to generically perform the refresh flow without regard to how config fields are get/set by
 39    delegating that behavior to the classes implementing the interface.
 40    """
 41
 42    _NO_STREAM_NAME = None
 43
 44    def __init__(
 45        self,
 46        refresh_token_error_status_codes: Tuple[int, ...] = (),
 47        refresh_token_error_key: str = "",
 48        refresh_token_error_values: Tuple[str, ...] = (),
 49    ) -> None:
 50        """
 51        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 52        then http errors with such params will be wrapped in AirbyteTracedException.
 53        """
 54        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 55        self._refresh_token_error_key = refresh_token_error_key
 56        self._refresh_token_error_values = refresh_token_error_values
 57
 58    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 59        """Attach the HTTP headers required to authenticate on the HTTP request"""
 60        request.headers.update(self.get_auth_header())
 61        return request
 62
 63    @property
 64    def _is_access_token_flow(self) -> bool:
 65        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 66
 67    @property
 68    def token_expiry_is_time_of_expiration(self) -> bool:
 69        """
 70        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
 71        """
 72
 73        return False
 74
 75    @property
 76    def token_expiry_date_format(self) -> Optional[str]:
 77        """
 78        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
 79        """
 80
 81        return None
 82
 83    def get_auth_header(self) -> Mapping[str, Any]:
 84        """HTTP header to set on the requests"""
 85        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 86        return {"Authorization": f"Bearer {token}"}
 87
 88    def get_access_token(self) -> str:
 89        """Returns the access token"""
 90        if self.token_has_expired():
 91            token, expires_in = self.refresh_access_token()
 92            self.access_token = token
 93            self.set_token_expiry_date(expires_in)
 94
 95        return self.access_token
 96
 97    def token_has_expired(self) -> bool:
 98        """Returns True if the token is expired"""
 99        return ab_datetime_now() > self.get_token_expiry_date()
100
101    def build_refresh_request_body(self) -> Mapping[str, Any]:
102        """
103        Returns the request body to set on the refresh request.
104
105        Override to define additional parameters.
106
107        Client credentials (client_id and client_secret) are excluded from the body when
108        refresh_request_headers contains an Authorization header (e.g., Basic auth).
109        This is required by OAuth providers like Gong that expect credentials ONLY in the
110        Authorization header and reject requests that include them in both places.
111        """
112        # Check if credentials are being sent via Authorization header
113        headers = self.get_refresh_request_headers()
114        credentials_in_header = headers and "Authorization" in headers
115
116        # Only include client credentials in body if not already in header
117        include_client_credentials = not credentials_in_header
118
119        payload: MutableMapping[str, Any] = {
120            self.get_grant_type_name(): self.get_grant_type(),
121        }
122
123        # Only include client credentials in body if configured to do so and not in header
124        if include_client_credentials:
125            payload[self.get_client_id_name()] = self.get_client_id()
126            payload[self.get_client_secret_name()] = self.get_client_secret()
127
128        payload[self.get_refresh_token_name()] = self.get_refresh_token()
129
130        if self.get_scopes():
131            payload["scopes"] = self.get_scopes()
132
133        if self.get_refresh_request_body():
134            for key, val in self.get_refresh_request_body().items():
135                # We defer to existing oauth constructs over custom configured fields
136                if key not in payload:
137                    payload[key] = val
138
139        return payload
140
141    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
142        """
143        Returns the request headers to set on the refresh request
144
145        """
146        headers = self.get_refresh_request_headers()
147        return headers if headers else None
148
149    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
150        """
151        Returns the refresh token and its expiration datetime
152
153        :return: a tuple of (access_token, token_lifespan)
154        """
155        response_json = self._make_handled_request()
156        self._ensure_access_token_in_response(response_json)
157
158        return (
159            self._extract_access_token(response_json),
160            self._extract_token_expiry_date(response_json),
161        )
162
163    # ----------------
164    # PRIVATE METHODS
165    # ----------------
166
167    def _default_token_expiry_date(self) -> AirbyteDateTime:
168        """
169        Returns the default token expiry date
170        """
171        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
172        default_token_expiry_duration_hours = 1  # 1 hour
173        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
174
175    def _wrap_refresh_token_exception(
176        self, exception: requests.exceptions.RequestException
177    ) -> bool:
178        """
179        Wraps and handles exceptions that occur during the refresh token process.
180
181        This method checks if the provided exception is related to a refresh token error
182        by examining the response status code and specific error content.
183
184        Args:
185            exception (requests.exceptions.RequestException): The exception raised during the request.
186
187        Returns:
188            bool: True if the exception is related to a refresh token error, False otherwise.
189        """
190        try:
191            if exception.response is not None:
192                exception_content = exception.response.json()
193            else:
194                return False
195        except JSONDecodeError:
196            return False
197        return (
198            exception.response.status_code in self._refresh_token_error_status_codes
199            and exception_content.get(self._refresh_token_error_key)
200            in self._refresh_token_error_values
201        )
202
203    @backoff.on_exception(
204        backoff.expo,
205        DefaultBackoffException,
206        on_backoff=lambda details: logger.info(
207            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
208        ),
209        max_time=300,
210    )
211    def _make_handled_request(self) -> Any:
212        """
213        Makes a handled HTTP request to refresh an OAuth token.
214
215        This method sends a POST request to the token refresh endpoint with the necessary
216        headers and body to obtain a new access token. It handles various exceptions that
217        may occur during the request and logs the response for troubleshooting purposes.
218
219        Returns:
220            Mapping[str, Any]: The JSON response from the token refresh endpoint.
221
222        Raises:
223            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
224                                     or any 5xx server error.
225            AirbyteTracedException: If the refresh token is invalid or expired, prompting
226                                    re-authentication.
227            Exception: For any other exceptions that occur during the request.
228        """
229        try:
230            response = requests.request(
231                method="POST",
232                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
233                data=self.build_refresh_request_body(),
234                headers=self.build_refresh_request_headers(),
235            )
236
237            if not response.ok:
238                # log the response even if the request failed for troubleshooting purposes
239                self._log_response(response)
240                response.raise_for_status()
241
242            response_json = response.json()
243
244            try:
245                # extract the access token and add to secrets to avoid logging the raw value
246                access_key = self._extract_access_token(response_json)
247                if access_key:
248                    add_to_secrets(access_key)
249            except ResponseKeysMaxRecurtionReached as e:
250                # could not find the access token in the response, so do nothing
251                pass
252
253            self._log_response(response)
254
255            return response_json
256        except requests.exceptions.RequestException as e:
257            if e.response is not None:
258                if e.response.status_code == 429 or e.response.status_code >= 500:
259                    raise DefaultBackoffException(
260                        request=e.response.request,
261                        response=e.response,
262                        failure_type=FailureType.transient_error,
263                    )
264            if self._wrap_refresh_token_exception(e):
265                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
266                raise AirbyteTracedException(
267                    internal_message=message, message=message, failure_type=FailureType.config_error
268                )
269            raise
270        except Exception as e:
271            raise Exception(f"Error while refreshing access token: {e}") from e
272
273    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
274        """
275        Ensures that the access token is present in the response data.
276
277        This method attempts to extract the access token from the provided response data.
278        If the access token is not found, it raises an exception indicating that the token
279        refresh API response was missing the access token.
280
281        Args:
282            response_data (Mapping[str, Any]): The response data from which to extract the access token.
283
284        Raises:
285            Exception: If the access token is not found in the response data.
286            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
287        """
288        try:
289            access_key = self._extract_access_token(response_data)
290            if not access_key:
291                raise Exception(
292                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
293                )
294        except ResponseKeysMaxRecurtionReached as e:
295            raise e
296
297    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
298        """
299        Parse a string or integer token expiration date into a datetime object
300
301        :return: expiration datetime
302        """
303        if self.token_expiry_is_time_of_expiration:
304            if not self.token_expiry_date_format:
305                raise ValueError(
306                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
307                )
308            try:
309                return ab_datetime_parse(str(value))
310            except ValueError as e:
311                raise ValueError(f"Invalid token expiry date format: {e}")
312        else:
313            try:
314                # Only accept numeric values (as int/float/string) when no format specified
315                seconds = int(float(str(value)))
316                return ab_datetime_now() + timedelta(seconds=seconds)
317            except (ValueError, TypeError):
318                raise ValueError(
319                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
320                )
321
322    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
323        """
324        Extracts the access token from the given response data.
325
326        Args:
327            response_data (Mapping[str, Any]): The response data from which to extract the access token.
328
329        Returns:
330            str: The extracted access token.
331        """
332        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
333
334    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
335        """
336        Extracts the refresh token from the given response data.
337
338        Args:
339            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
340
341        Returns:
342            str: The extracted refresh token.
343        """
344        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
345
346    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
347        """
348        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
349
350        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
351
352        Args:
353            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
354
355        Returns:
356            The extracted token_expiry_date or None if not found.
357        """
358        expires_in = self._find_and_get_value_from_response(
359            response_data, self.get_expires_in_name()
360        )
361        if expires_in is not None:
362            return self._parse_token_expiration_date(expires_in)
363
364        # expires_in is None
365        existing_expiry_date = self.get_token_expiry_date()
366        if existing_expiry_date and not self.token_has_expired():
367            return existing_expiry_date
368
369        return self._default_token_expiry_date()
370
371    def _find_and_get_value_from_response(
372        self,
373        response_data: Mapping[str, Any],
374        key_name: str,
375        max_depth: int = 5,
376        current_depth: int = 0,
377    ) -> Any:
378        """
379        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
380
381        Args:
382            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
383            key_name (str): The key to search for in the response data.
384            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
385            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
386
387        Returns:
388            Any: The value associated with the specified key if found, otherwise None.
389
390        Raises:
391            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
392        """
393        if current_depth > max_depth:
394            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
395            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
396            raise ResponseKeysMaxRecurtionReached(
397                internal_message=message, message=message, failure_type=FailureType.config_error
398            )
399
400        if isinstance(response_data, dict):
401            # get from the root level
402            if key_name in response_data:
403                return response_data[key_name]
404
405            # get from the nested object
406            for _, value in response_data.items():
407                result = self._find_and_get_value_from_response(
408                    value, key_name, max_depth, current_depth + 1
409                )
410                if result is not None:
411                    return result
412
413        # get from the nested array object
414        elif isinstance(response_data, list):
415            for item in response_data:
416                result = self._find_and_get_value_from_response(
417                    item, key_name, max_depth, current_depth + 1
418                )
419                if result is not None:
420                    return result
421
422        return None
423
424    @property
425    def _message_repository(self) -> Optional[MessageRepository]:
426        """
427        The implementation can define a message_repository if it wants debugging logs for HTTP requests
428        """
429        return _NOOP_MESSAGE_REPOSITORY
430
431    def _log_response(self, response: requests.Response) -> None:
432        """
433        Logs the HTTP response using the message repository if it is available.
434
435        Args:
436            response (requests.Response): The HTTP response to log.
437        """
438        if self._message_repository:
439            self._message_repository.log_message(
440                Level.DEBUG,
441                lambda: format_http_message(
442                    response,
443                    "Refresh token",
444                    "Obtains access token",
445                    self._NO_STREAM_NAME,
446                    is_auxiliary=True,
447                    type="AUTH",
448                ),
449            )
450
451    # ----------------
452    # ABSTR METHODS
453    # ----------------
454
455    @abstractmethod
456    def get_token_refresh_endpoint(self) -> Optional[str]:
457        """Returns the endpoint to refresh the access token"""
458
459    @abstractmethod
460    def get_client_id_name(self) -> str:
461        """The client id name to authenticate"""
462
463    @abstractmethod
464    def get_client_id(self) -> str:
465        """The client id to authenticate"""
466
467    @abstractmethod
468    def get_client_secret_name(self) -> str:
469        """The client secret name to authenticate"""
470
471    @abstractmethod
472    def get_client_secret(self) -> str:
473        """The client secret to authenticate"""
474
475    @abstractmethod
476    def get_refresh_token_name(self) -> str:
477        """The refresh token name to authenticate"""
478
479    @abstractmethod
480    def get_refresh_token(self) -> Optional[str]:
481        """The token used to refresh the access token when it expires"""
482
483    @abstractmethod
484    def get_scopes(self) -> List[str]:
485        """List of requested scopes"""
486
487    @abstractmethod
488    def get_token_expiry_date(self) -> AirbyteDateTime:
489        """Expiration date of the access token"""
490
491    @abstractmethod
492    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
493        """Setter for access token expiration date"""
494
495    @abstractmethod
496    def get_access_token_name(self) -> str:
497        """Field to extract access token from in the response"""
498
499    @abstractmethod
500    def get_expires_in_name(self) -> str:
501        """Returns the expires_in field name"""
502
503    @abstractmethod
504    def get_refresh_request_body(self) -> Mapping[str, Any]:
505        """Returns the request body to set on the refresh request"""
506
507    @abstractmethod
508    def get_refresh_request_headers(self) -> Mapping[str, Any]:
509        """Returns the request headers to set on the refresh request"""
510
511    @abstractmethod
512    def get_grant_type(self) -> str:
513        """Returns grant_type specified for requesting access_token"""
514
515    @abstractmethod
516    def get_grant_type_name(self) -> str:
517        """Returns grant_type specified name for requesting access_token"""
518
519    @property
520    @abstractmethod
521    def access_token(self) -> str:
522        """Returns the access token"""
523
524    @access_token.setter
525    @abstractmethod
526    def access_token(self, value: str) -> str:
527        """Setter for the access token"""
logger = <Logger airbyte (INFO)>
class ResponseKeysMaxRecurtionReached(airbyte_cdk.utils.traced_exception.AirbyteTracedException):
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Signals that the recursive key lookup over a token refresh response
    (`_find_and_get_value_from_response`, used while handling the response of
    `_make_handled_request`) went deeper than its maximum allowed depth
    without locating the requested key.
    """

Raised when the maximum recursion depth is reached while trying to find and get the target key during `_make_handled_request`.

class AbstractOauth2Authenticator(requests.auth.AuthBase):
 36class AbstractOauth2Authenticator(AuthBase):
 37    """
 38    Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
 39    is designed to generically perform the refresh flow without regard to how config fields are get/set by
 40    delegating that behavior to the classes implementing the interface.
 41    """
 42
 43    _NO_STREAM_NAME = None
 44
 45    def __init__(
 46        self,
 47        refresh_token_error_status_codes: Tuple[int, ...] = (),
 48        refresh_token_error_key: str = "",
 49        refresh_token_error_values: Tuple[str, ...] = (),
 50    ) -> None:
 51        """
 52        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 53        then http errors with such params will be wrapped in AirbyteTracedException.
 54        """
 55        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 56        self._refresh_token_error_key = refresh_token_error_key
 57        self._refresh_token_error_values = refresh_token_error_values
 58
 59    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 60        """Attach the HTTP headers required to authenticate on the HTTP request"""
 61        request.headers.update(self.get_auth_header())
 62        return request
 63
 64    @property
 65    def _is_access_token_flow(self) -> bool:
 66        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 67
 68    @property
 69    def token_expiry_is_time_of_expiration(self) -> bool:
 70        """
 71        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
 72        """
 73
 74        return False
 75
 76    @property
 77    def token_expiry_date_format(self) -> Optional[str]:
 78        """
 79        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
 80        """
 81
 82        return None
 83
 84    def get_auth_header(self) -> Mapping[str, Any]:
 85        """HTTP header to set on the requests"""
 86        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 87        return {"Authorization": f"Bearer {token}"}
 88
 89    def get_access_token(self) -> str:
 90        """Returns the access token"""
 91        if self.token_has_expired():
 92            token, expires_in = self.refresh_access_token()
 93            self.access_token = token
 94            self.set_token_expiry_date(expires_in)
 95
 96        return self.access_token
 97
 98    def token_has_expired(self) -> bool:
 99        """Returns True if the token is expired"""
100        return ab_datetime_now() > self.get_token_expiry_date()
101
102    def build_refresh_request_body(self) -> Mapping[str, Any]:
103        """
104        Returns the request body to set on the refresh request.
105
106        Override to define additional parameters.
107
108        Client credentials (client_id and client_secret) are excluded from the body when
109        refresh_request_headers contains an Authorization header (e.g., Basic auth).
110        This is required by OAuth providers like Gong that expect credentials ONLY in the
111        Authorization header and reject requests that include them in both places.
112        """
113        # Check if credentials are being sent via Authorization header
114        headers = self.get_refresh_request_headers()
115        credentials_in_header = headers and "Authorization" in headers
116
117        # Only include client credentials in body if not already in header
118        include_client_credentials = not credentials_in_header
119
120        payload: MutableMapping[str, Any] = {
121            self.get_grant_type_name(): self.get_grant_type(),
122        }
123
124        # Only include client credentials in body if configured to do so and not in header
125        if include_client_credentials:
126            payload[self.get_client_id_name()] = self.get_client_id()
127            payload[self.get_client_secret_name()] = self.get_client_secret()
128
129        payload[self.get_refresh_token_name()] = self.get_refresh_token()
130
131        if self.get_scopes():
132            payload["scopes"] = self.get_scopes()
133
134        if self.get_refresh_request_body():
135            for key, val in self.get_refresh_request_body().items():
136                # We defer to existing oauth constructs over custom configured fields
137                if key not in payload:
138                    payload[key] = val
139
140        return payload
141
142    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
143        """
144        Returns the request headers to set on the refresh request
145
146        """
147        headers = self.get_refresh_request_headers()
148        return headers if headers else None
149
150    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
151        """
152        Returns the refresh token and its expiration datetime
153
154        :return: a tuple of (access_token, token_lifespan)
155        """
156        response_json = self._make_handled_request()
157        self._ensure_access_token_in_response(response_json)
158
159        return (
160            self._extract_access_token(response_json),
161            self._extract_token_expiry_date(response_json),
162        )
163
164    # ----------------
165    # PRIVATE METHODS
166    # ----------------
167
168    def _default_token_expiry_date(self) -> AirbyteDateTime:
169        """
170        Returns the default token expiry date
171        """
172        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
173        default_token_expiry_duration_hours = 1  # 1 hour
174        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
175
176    def _wrap_refresh_token_exception(
177        self, exception: requests.exceptions.RequestException
178    ) -> bool:
179        """
180        Wraps and handles exceptions that occur during the refresh token process.
181
182        This method checks if the provided exception is related to a refresh token error
183        by examining the response status code and specific error content.
184
185        Args:
186            exception (requests.exceptions.RequestException): The exception raised during the request.
187
188        Returns:
189            bool: True if the exception is related to a refresh token error, False otherwise.
190        """
191        try:
192            if exception.response is not None:
193                exception_content = exception.response.json()
194            else:
195                return False
196        except JSONDecodeError:
197            return False
198        return (
199            exception.response.status_code in self._refresh_token_error_status_codes
200            and exception_content.get(self._refresh_token_error_key)
201            in self._refresh_token_error_values
202        )
203
204    @backoff.on_exception(
205        backoff.expo,
206        DefaultBackoffException,
207        on_backoff=lambda details: logger.info(
208            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
209        ),
210        max_time=300,
211    )
212    def _make_handled_request(self) -> Any:
213        """
214        Makes a handled HTTP request to refresh an OAuth token.
215
216        This method sends a POST request to the token refresh endpoint with the necessary
217        headers and body to obtain a new access token. It handles various exceptions that
218        may occur during the request and logs the response for troubleshooting purposes.
219
220        Returns:
221            Mapping[str, Any]: The JSON response from the token refresh endpoint.
222
223        Raises:
224            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
225                                     or any 5xx server error.
226            AirbyteTracedException: If the refresh token is invalid or expired, prompting
227                                    re-authentication.
228            Exception: For any other exceptions that occur during the request.
229        """
230        try:
231            response = requests.request(
232                method="POST",
233                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
234                data=self.build_refresh_request_body(),
235                headers=self.build_refresh_request_headers(),
236            )
237
238            if not response.ok:
239                # log the response even if the request failed for troubleshooting purposes
240                self._log_response(response)
241                response.raise_for_status()
242
243            response_json = response.json()
244
245            try:
246                # extract the access token and add to secrets to avoid logging the raw value
247                access_key = self._extract_access_token(response_json)
248                if access_key:
249                    add_to_secrets(access_key)
250            except ResponseKeysMaxRecurtionReached as e:
251                # could not find the access token in the response, so do nothing
252                pass
253
254            self._log_response(response)
255
256            return response_json
257        except requests.exceptions.RequestException as e:
258            if e.response is not None:
259                if e.response.status_code == 429 or e.response.status_code >= 500:
260                    raise DefaultBackoffException(
261                        request=e.response.request,
262                        response=e.response,
263                        failure_type=FailureType.transient_error,
264                    )
265            if self._wrap_refresh_token_exception(e):
266                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
267                raise AirbyteTracedException(
268                    internal_message=message, message=message, failure_type=FailureType.config_error
269                )
270            raise
271        except Exception as e:
272            raise Exception(f"Error while refreshing access token: {e}") from e
273
274    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
275        """
276        Ensures that the access token is present in the response data.
277
278        This method attempts to extract the access token from the provided response data.
279        If the access token is not found, it raises an exception indicating that the token
280        refresh API response was missing the access token.
281
282        Args:
283            response_data (Mapping[str, Any]): The response data from which to extract the access token.
284
285        Raises:
286            Exception: If the access token is not found in the response data.
287            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
288        """
289        try:
290            access_key = self._extract_access_token(response_data)
291            if not access_key:
292                raise Exception(
293                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
294                )
295        except ResponseKeysMaxRecurtionReached as e:
296            raise e
297
298    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
299        """
300        Parse a string or integer token expiration date into a datetime object
301
302        :return: expiration datetime
303        """
304        if self.token_expiry_is_time_of_expiration:
305            if not self.token_expiry_date_format:
306                raise ValueError(
307                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
308                )
309            try:
310                return ab_datetime_parse(str(value))
311            except ValueError as e:
312                raise ValueError(f"Invalid token expiry date format: {e}")
313        else:
314            try:
315                # Only accept numeric values (as int/float/string) when no format specified
316                seconds = int(float(str(value)))
317                return ab_datetime_now() + timedelta(seconds=seconds)
318            except (ValueError, TypeError):
319                raise ValueError(
320                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
321                )
322
323    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
324        """
325        Extracts the access token from the given response data.
326
327        Args:
328            response_data (Mapping[str, Any]): The response data from which to extract the access token.
329
330        Returns:
331            str: The extracted access token.
332        """
333        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
334
335    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
336        """
337        Extracts the refresh token from the given response data.
338
339        Args:
340            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
341
342        Returns:
343            str: The extracted refresh token.
344        """
345        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
346
347    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
348        """
349        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
350
351        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
352
353        Args:
354            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
355
356        Returns:
357            The extracted token_expiry_date or None if not found.
358        """
359        expires_in = self._find_and_get_value_from_response(
360            response_data, self.get_expires_in_name()
361        )
362        if expires_in is not None:
363            return self._parse_token_expiration_date(expires_in)
364
365        # expires_in is None
366        existing_expiry_date = self.get_token_expiry_date()
367        if existing_expiry_date and not self.token_has_expired():
368            return existing_expiry_date
369
370        return self._default_token_expiry_date()
371
372    def _find_and_get_value_from_response(
373        self,
374        response_data: Mapping[str, Any],
375        key_name: str,
376        max_depth: int = 5,
377        current_depth: int = 0,
378    ) -> Any:
379        """
380        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
381
382        Args:
383            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
384            key_name (str): The key to search for in the response data.
385            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
386            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
387
388        Returns:
389            Any: The value associated with the specified key if found, otherwise None.
390
391        Raises:
392            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
393        """
394        if current_depth > max_depth:
395            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
396            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
397            raise ResponseKeysMaxRecurtionReached(
398                internal_message=message, message=message, failure_type=FailureType.config_error
399            )
400
401        if isinstance(response_data, dict):
402            # get from the root level
403            if key_name in response_data:
404                return response_data[key_name]
405
406            # get from the nested object
407            for _, value in response_data.items():
408                result = self._find_and_get_value_from_response(
409                    value, key_name, max_depth, current_depth + 1
410                )
411                if result is not None:
412                    return result
413
414        # get from the nested array object
415        elif isinstance(response_data, list):
416            for item in response_data:
417                result = self._find_and_get_value_from_response(
418                    item, key_name, max_depth, current_depth + 1
419                )
420                if result is not None:
421                    return result
422
423        return None
424
425    @property
426    def _message_repository(self) -> Optional[MessageRepository]:
427        """
428        The implementation can define a message_repository if it wants debugging logs for HTTP requests
429        """
430        return _NOOP_MESSAGE_REPOSITORY
431
432    def _log_response(self, response: requests.Response) -> None:
433        """
434        Logs the HTTP response using the message repository if it is available.
435
436        Args:
437            response (requests.Response): The HTTP response to log.
438        """
439        if self._message_repository:
440            self._message_repository.log_message(
441                Level.DEBUG,
442                lambda: format_http_message(
443                    response,
444                    "Refresh token",
445                    "Obtains access token",
446                    self._NO_STREAM_NAME,
447                    is_auxiliary=True,
448                    type="AUTH",
449                ),
450            )
451
452    # ----------------
453    # ABSTR METHODS
454    # ----------------
455
456    @abstractmethod
457    def get_token_refresh_endpoint(self) -> Optional[str]:
458        """Returns the endpoint to refresh the access token"""
459
460    @abstractmethod
461    def get_client_id_name(self) -> str:
462        """The client id name to authenticate"""
463
464    @abstractmethod
465    def get_client_id(self) -> str:
466        """The client id to authenticate"""
467
468    @abstractmethod
469    def get_client_secret_name(self) -> str:
470        """The client secret name to authenticate"""
471
472    @abstractmethod
473    def get_client_secret(self) -> str:
474        """The client secret to authenticate"""
475
476    @abstractmethod
477    def get_refresh_token_name(self) -> str:
478        """The refresh token name to authenticate"""
479
480    @abstractmethod
481    def get_refresh_token(self) -> Optional[str]:
482        """The token used to refresh the access token when it expires"""
483
484    @abstractmethod
485    def get_scopes(self) -> List[str]:
486        """List of requested scopes"""
487
488    @abstractmethod
489    def get_token_expiry_date(self) -> AirbyteDateTime:
490        """Expiration date of the access token"""
491
492    @abstractmethod
493    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
494        """Setter for access token expiration date"""
495
496    @abstractmethod
497    def get_access_token_name(self) -> str:
498        """Field to extract access token from in the response"""
499
500    @abstractmethod
501    def get_expires_in_name(self) -> str:
502        """Returns the expires_in field name"""
503
504    @abstractmethod
505    def get_refresh_request_body(self) -> Mapping[str, Any]:
506        """Returns the request body to set on the refresh request"""
507
508    @abstractmethod
509    def get_refresh_request_headers(self) -> Mapping[str, Any]:
510        """Returns the request headers to set on the refresh request"""
511
512    @abstractmethod
513    def get_grant_type(self) -> str:
514        """Returns grant_type specified for requesting access_token"""
515
516    @abstractmethod
517    def get_grant_type_name(self) -> str:
518        """Returns grant_type specified name for requesting access_token"""
519
520    @property
521    @abstractmethod
522    def access_token(self) -> str:
523        """Returns the access token"""
524
525    @access_token.setter
526    @abstractmethod
527    def access_token(self, value: str) -> str:
528        """Setter for the access token"""

Abstract base class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator performs the refresh flow generically, without regard to how config fields are read or written, by delegating that behavior to the classes implementing the interface.

AbstractOauth2Authenticator( refresh_token_error_status_codes: Tuple[int, ...] = (), refresh_token_error_key: str = '', refresh_token_error_values: Tuple[str, ...] = ())
45    def __init__(
46        self,
47        refresh_token_error_status_codes: Tuple[int, ...] = (),
48        refresh_token_error_key: str = "",
49        refresh_token_error_values: Tuple[str, ...] = (),
50    ) -> None:
51        """
52        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
53        then http errors with such params will be wrapped in AirbyteTracedException.
54        """
55        self._refresh_token_error_status_codes = refresh_token_error_status_codes
56        self._refresh_token_error_key = refresh_token_error_key
57        self._refresh_token_error_values = refresh_token_error_values

If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, then HTTP errors matching these parameters will be wrapped in an AirbyteTracedException.

token_expiry_is_time_of_expiration: bool
68    @property
69    def token_expiry_is_time_of_expiration(self) -> bool:
70        """
71        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
72        """
73
74        return False

Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.

token_expiry_date_format: Optional[str]
76    @property
77    def token_expiry_date_format(self) -> Optional[str]:
78        """
79        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
80        """
81
82        return None

Format of the expiration datetime; set it if expires_in is returned as the expiration datetime instead of the number of seconds until the token expires.

def get_auth_header(self) -> Mapping[str, Any]:
84    def get_auth_header(self) -> Mapping[str, Any]:
85        """HTTP header to set on the requests"""
86        token = self.access_token if self._is_access_token_flow else self.get_access_token()
87        return {"Authorization": f"Bearer {token}"}

HTTP header to set on the requests

def get_access_token(self) -> str:
89    def get_access_token(self) -> str:
90        """Returns the access token"""
91        if self.token_has_expired():
92            token, expires_in = self.refresh_access_token()
93            self.access_token = token
94            self.set_token_expiry_date(expires_in)
95
96        return self.access_token

Returns the access token

def token_has_expired(self) -> bool:
 98    def token_has_expired(self) -> bool:
 99        """Returns True if the token is expired"""
100        return ab_datetime_now() > self.get_token_expiry_date()

Returns True if the token is expired

def build_refresh_request_body(self) -> Mapping[str, Any]:
102    def build_refresh_request_body(self) -> Mapping[str, Any]:
103        """
104        Returns the request body to set on the refresh request.
105
106        Override to define additional parameters.
107
108        Client credentials (client_id and client_secret) are excluded from the body when
109        refresh_request_headers contains an Authorization header (e.g., Basic auth).
110        This is required by OAuth providers like Gong that expect credentials ONLY in the
111        Authorization header and reject requests that include them in both places.
112        """
113        # Check if credentials are being sent via Authorization header
114        headers = self.get_refresh_request_headers()
115        credentials_in_header = headers and "Authorization" in headers
116
117        # Only include client credentials in body if not already in header
118        include_client_credentials = not credentials_in_header
119
120        payload: MutableMapping[str, Any] = {
121            self.get_grant_type_name(): self.get_grant_type(),
122        }
123
124        # Only include client credentials in body if configured to do so and not in header
125        if include_client_credentials:
126            payload[self.get_client_id_name()] = self.get_client_id()
127            payload[self.get_client_secret_name()] = self.get_client_secret()
128
129        payload[self.get_refresh_token_name()] = self.get_refresh_token()
130
131        if self.get_scopes():
132            payload["scopes"] = self.get_scopes()
133
134        if self.get_refresh_request_body():
135            for key, val in self.get_refresh_request_body().items():
136                # We defer to existing oauth constructs over custom configured fields
137                if key not in payload:
138                    payload[key] = val
139
140        return payload

Returns the request body to set on the refresh request.

Override to define additional parameters.

Client credentials (client_id and client_secret) are excluded from the body when refresh_request_headers contains an Authorization header (e.g., Basic auth). This is required by OAuth providers like Gong that expect credentials ONLY in the Authorization header and reject requests that include them in both places.

def build_refresh_request_headers(self) -> Optional[Mapping[str, Any]]:
142    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
143        """
144        Returns the request headers to set on the refresh request
145
146        """
147        headers = self.get_refresh_request_headers()
148        return headers if headers else None

Returns the request headers to set on the refresh request

def refresh_access_token(self) -> Tuple[str, airbyte_cdk.utils.datetime_helpers.AirbyteDateTime]:
150    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
151        """
152        Returns the refresh token and its expiration datetime
153
154        :return: a tuple of (access_token, token_lifespan)
155        """
156        response_json = self._make_handled_request()
157        self._ensure_access_token_in_response(response_json)
158
159        return (
160            self._extract_access_token(response_json),
161            self._extract_token_expiry_date(response_json),
162        )

Requests a new access token and returns it together with its expiration datetime.

Returns

a tuple of (access_token, token_expiry_date), where token_expiry_date is the datetime at which the new access token expires

@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
456    @abstractmethod
457    def get_token_refresh_endpoint(self) -> Optional[str]:
458        """Returns the endpoint to refresh the access token"""

Returns the endpoint to refresh the access token

@abstractmethod
def get_client_id_name(self) -> str:
460    @abstractmethod
461    def get_client_id_name(self) -> str:
462        """The client id name to authenticate"""

The name of the refresh request body field that carries the client id

@abstractmethod
def get_client_id(self) -> str:
464    @abstractmethod
465    def get_client_id(self) -> str:
466        """The client id to authenticate"""

The client id to authenticate

@abstractmethod
def get_client_secret_name(self) -> str:
468    @abstractmethod
469    def get_client_secret_name(self) -> str:
470        """The client secret name to authenticate"""

The name of the refresh request body field that carries the client secret

@abstractmethod
def get_client_secret(self) -> str:
472    @abstractmethod
473    def get_client_secret(self) -> str:
474        """The client secret to authenticate"""

The client secret to authenticate

@abstractmethod
def get_refresh_token_name(self) -> str:
476    @abstractmethod
477    def get_refresh_token_name(self) -> str:
478        """The refresh token name to authenticate"""

The name of the refresh token field in the refresh request body and response

@abstractmethod
def get_refresh_token(self) -> Optional[str]:
480    @abstractmethod
481    def get_refresh_token(self) -> Optional[str]:
482        """The token used to refresh the access token when it expires"""

The token used to refresh the access token when it expires

@abstractmethod
def get_scopes(self) -> List[str]:
484    @abstractmethod
485    def get_scopes(self) -> List[str]:
486        """List of requested scopes"""

List of requested scopes

@abstractmethod
def get_token_expiry_date(self) -> airbyte_cdk.utils.datetime_helpers.AirbyteDateTime:
488    @abstractmethod
489    def get_token_expiry_date(self) -> AirbyteDateTime:
490        """Expiration date of the access token"""

Expiration date of the access token

@abstractmethod
def set_token_expiry_date(self, value: airbyte_cdk.utils.datetime_helpers.AirbyteDateTime) -> None:
492    @abstractmethod
493    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
494        """Setter for access token expiration date"""

Setter for access token expiration date

@abstractmethod
def get_access_token_name(self) -> str:
496    @abstractmethod
497    def get_access_token_name(self) -> str:
498        """Field to extract access token from in the response"""

Field to extract access token from in the response

@abstractmethod
def get_expires_in_name(self) -> str:
500    @abstractmethod
501    def get_expires_in_name(self) -> str:
502        """Returns the expires_in field name"""

Returns the expires_in field name

@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
504    @abstractmethod
505    def get_refresh_request_body(self) -> Mapping[str, Any]:
506        """Returns the request body to set on the refresh request"""

Returns the request body to set on the refresh request

@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
508    @abstractmethod
509    def get_refresh_request_headers(self) -> Mapping[str, Any]:
510        """Returns the request headers to set on the refresh request"""

Returns the request headers to set on the refresh request

@abstractmethod
def get_grant_type(self) -> str:
512    @abstractmethod
513    def get_grant_type(self) -> str:
514        """Returns grant_type specified for requesting access_token"""

Returns the grant type value sent when requesting an access token

@abstractmethod
def get_grant_type_name(self) -> str:
516    @abstractmethod
517    def get_grant_type_name(self) -> str:
518        """Returns grant_type specified name for requesting access_token"""

Returns the name of the request body field that carries the grant type when requesting an access token

access_token: str
520    @property
521    @abstractmethod
522    def access_token(self) -> str:
523        """Returns the access token"""

Returns the access token