airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth

  1#
  2# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
  3#
  4
  5import logging
  6import threading
  7from abc import abstractmethod
  8from datetime import timedelta
  9from json import JSONDecodeError
 10from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union
 11
 12import backoff
 13import requests
 14from requests.auth import AuthBase
 15
 16from airbyte_cdk.models import FailureType, Level
 17from airbyte_cdk.sources.http_logger import format_http_message
 18from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
 19from airbyte_cdk.utils import AirbyteTracedException
 20from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
 21from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse
 22
 23from ..exceptions import DefaultBackoffException
 24
# Module logger named "airbyte"; the token-refresh backoff handler below logs retry notices through it.
logger = logging.getLogger("airbyte")
# Shared default returned by `AbstractOauth2Authenticator._message_repository`.
# NOTE(review): presumably discards HTTP debug messages unless a subclass supplies a real
# repository — confirm against NoopMessageRepository.
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()
 27
 28
 29class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
 30    """
 31    Raised when the max level of recursion is reached, when trying to
 32    find-and-get the target key, during the `_make_handled_request`
 33    """
 34
 35
 36class AbstractOauth2Authenticator(AuthBase):
 37    """
 38    Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
 39    is designed to generically perform the refresh flow without regard to how config fields are get/set by
 40    delegating that behavior to the classes implementing the interface.
 41    """
 42
 43    _NO_STREAM_NAME = None
 44
 45    # Class-level lock to prevent concurrent token refresh across multiple authenticator instances.
 46    # This is necessary because multiple streams may share the same OAuth credentials (refresh token)
 47    # through the connector config. Without this lock, concurrent refresh attempts can cause race
 48    # conditions where one stream successfully refreshes the token while others fail because the
 49    # refresh token has been invalidated (especially for single-use refresh tokens).
 50    _token_refresh_lock: threading.Lock = threading.Lock()
 51
 52    def __init__(
 53        self,
 54        refresh_token_error_status_codes: Tuple[int, ...] = (),
 55        refresh_token_error_key: str = "",
 56        refresh_token_error_values: Tuple[str, ...] = (),
 57    ) -> None:
 58        """
 59        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 60        then http errors with such params will be wrapped in AirbyteTracedException.
 61        """
 62        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 63        self._refresh_token_error_key = refresh_token_error_key
 64        self._refresh_token_error_values = refresh_token_error_values
 65
 66    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 67        """Attach the HTTP headers required to authenticate on the HTTP request"""
 68        request.headers.update(self.get_auth_header())
 69        return request
 70
 71    @property
 72    def _is_access_token_flow(self) -> bool:
 73        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 74
 75    @property
 76    def token_expiry_is_time_of_expiration(self) -> bool:
 77        """
 78        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
 79        """
 80
 81        return False
 82
 83    @property
 84    def token_expiry_date_format(self) -> Optional[str]:
 85        """
 86        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
 87        """
 88
 89        return None
 90
 91    def get_auth_header(self) -> Mapping[str, Any]:
 92        """HTTP header to set on the requests"""
 93        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 94        return {"Authorization": f"Bearer {token}"}
 95
 96    def get_access_token(self) -> str:
 97        """
 98        Returns the access token.
 99
100        This method uses double-checked locking to ensure thread-safe token refresh.
101        When multiple threads (streams) detect an expired token simultaneously, only one
102        will perform the refresh while others wait. After acquiring the lock, the token
103        expiry is re-checked to avoid redundant refresh attempts.
104        """
105        if self.token_has_expired():
106            with self._token_refresh_lock:
107                # Double-check after acquiring lock - another thread may have already refreshed
108                if self.token_has_expired():
109                    self.refresh_and_set_access_token()
110
111        return self.access_token
112
113    def refresh_and_set_access_token(self) -> None:
114        """Force refresh the access token and update internal state.
115
116        This method refreshes the access token regardless of whether it has expired,
117        and updates the internal token and expiry date. Subclasses may override this
118        to handle additional state updates (e.g., persisting new refresh tokens).
119        """
120        token, expires_in = self.refresh_access_token()
121        self.access_token = token
122        self.set_token_expiry_date(expires_in)
123
124    def token_has_expired(self) -> bool:
125        """Returns True if the token is expired"""
126        return ab_datetime_now() > self.get_token_expiry_date()
127
128    def build_refresh_request_body(self) -> Mapping[str, Any]:
129        """
130        Returns the request body to set on the refresh request.
131
132        Override to define additional parameters.
133
134        Client credentials (client_id and client_secret) are excluded from the body when
135        refresh_request_headers contains an Authorization header (e.g., Basic auth).
136        This is required by OAuth providers like Gong that expect credentials ONLY in the
137        Authorization header and reject requests that include them in both places.
138        """
139        # Check if credentials are being sent via Authorization header
140        headers = self.get_refresh_request_headers()
141        credentials_in_header = headers and "Authorization" in headers
142
143        # Only include client credentials in body if not already in header
144        include_client_credentials = not credentials_in_header
145
146        payload: MutableMapping[str, Any] = {
147            self.get_grant_type_name(): self.get_grant_type(),
148        }
149
150        # Only include client credentials in body if configured to do so and not in header
151        if include_client_credentials:
152            payload[self.get_client_id_name()] = self.get_client_id()
153            payload[self.get_client_secret_name()] = self.get_client_secret()
154
155        payload[self.get_refresh_token_name()] = self.get_refresh_token()
156
157        if self.get_scopes():
158            payload["scopes"] = self.get_scopes()
159
160        if self.get_refresh_request_body():
161            for key, val in self.get_refresh_request_body().items():
162                # We defer to existing oauth constructs over custom configured fields
163                if key not in payload:
164                    payload[key] = val
165
166        return payload
167
168    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
169        """
170        Returns the request headers to set on the refresh request
171
172        """
173        headers = self.get_refresh_request_headers()
174        return headers if headers else None
175
176    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
177        """
178        Returns the refresh token and its expiration datetime
179
180        :return: a tuple of (access_token, token_lifespan)
181        """
182        try:
183            response_json = self._make_handled_request()
184        except (
185            requests.exceptions.ConnectionError,
186            requests.exceptions.ConnectTimeout,
187            requests.exceptions.ReadTimeout,
188        ) as e:
189            raise AirbyteTracedException(
190                message="OAuth access token refresh request failed due to a network error.",
191                internal_message=f"Network error during OAuth token refresh after retries were exhausted: {e}",
192                failure_type=FailureType.transient_error,
193            ) from e
194        self._ensure_access_token_in_response(response_json)
195
196        return (
197            self._extract_access_token(response_json),
198            self._extract_token_expiry_date(response_json),
199        )
200
201    # ----------------
202    # PRIVATE METHODS
203    # ----------------
204
205    def _default_token_expiry_date(self) -> AirbyteDateTime:
206        """
207        Returns the default token expiry date
208        """
209        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
210        default_token_expiry_duration_hours = 1  # 1 hour
211        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
212
213    def _wrap_refresh_token_exception(
214        self, exception: requests.exceptions.RequestException
215    ) -> bool:
216        """
217        Wraps and handles exceptions that occur during the refresh token process.
218
219        This method checks if the provided exception is related to a refresh token error
220        by examining the response status code and specific error content.
221
222        Args:
223            exception (requests.exceptions.RequestException): The exception raised during the request.
224
225        Returns:
226            bool: True if the exception is related to a refresh token error, False otherwise.
227        """
228        try:
229            if exception.response is not None:
230                exception_content = exception.response.json()
231            else:
232                return False
233        except JSONDecodeError:
234            return False
235        return (
236            exception.response.status_code in self._refresh_token_error_status_codes
237            and exception_content.get(self._refresh_token_error_key)
238            in self._refresh_token_error_values
239        )
240
241    @backoff.on_exception(
242        backoff.expo,
243        (
244            DefaultBackoffException,
245            requests.exceptions.ConnectionError,
246            requests.exceptions.ConnectTimeout,
247            requests.exceptions.ReadTimeout,
248        ),
249        on_backoff=lambda details: logger.info(
250            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
251        ),
252        max_time=300,
253    )
254    def _make_handled_request(self) -> Any:
255        """
256        Makes a handled HTTP request to refresh an OAuth token.
257
258        This method sends a POST request to the token refresh endpoint with the necessary
259        headers and body to obtain a new access token. It handles various exceptions that
260        may occur during the request and logs the response for troubleshooting purposes.
261
262        Returns:
263            Mapping[str, Any]: The JSON response from the token refresh endpoint.
264
265        Raises:
266            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
267                                     or any 5xx server error.
268            AirbyteTracedException: If the refresh token is invalid or expired, prompting
269                                    re-authentication.
270            Exception: For any other exceptions that occur during the request.
271        """
272        try:
273            response = requests.request(
274                method="POST",
275                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
276                data=self.build_refresh_request_body(),
277                headers=self.build_refresh_request_headers(),
278            )
279
280            if not response.ok:
281                # log the response even if the request failed for troubleshooting purposes
282                self._log_response(response)
283                response.raise_for_status()
284
285            response_json = response.json()
286
287            try:
288                # extract the access token and add to secrets to avoid logging the raw value
289                access_key = self._extract_access_token(response_json)
290                if access_key:
291                    add_to_secrets(access_key)
292            except ResponseKeysMaxRecurtionReached as e:
293                # could not find the access token in the response, so do nothing
294                pass
295
296            self._log_response(response)
297
298            return response_json
299        except requests.exceptions.RequestException as e:
300            if e.response is not None:
301                if e.response.status_code == 429 or e.response.status_code >= 500:
302                    raise DefaultBackoffException(
303                        request=e.response.request,
304                        response=e.response,
305                        failure_type=FailureType.transient_error,
306                    )
307            if self._wrap_refresh_token_exception(e):
308                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
309                raise AirbyteTracedException(
310                    internal_message=message, message=message, failure_type=FailureType.config_error
311                )
312            raise
313        except Exception as e:
314            raise AirbyteTracedException(
315                message="OAuth access token refresh request failed.",
316                internal_message=f"Unexpected error during OAuth token refresh: {e}",
317                failure_type=FailureType.system_error,
318            ) from e
319
320    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
321        """
322        Ensures that the access token is present in the response data.
323
324        This method attempts to extract the access token from the provided response data.
325        If the access token is not found, it raises an exception indicating that the token
326        refresh API response was missing the access token.
327
328        Args:
329            response_data (Mapping[str, Any]): The response data from which to extract the access token.
330
331        Raises:
332            Exception: If the access token is not found in the response data.
333            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
334        """
335        try:
336            access_key = self._extract_access_token(response_data)
337            if not access_key:
338                raise Exception(
339                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
340                )
341        except ResponseKeysMaxRecurtionReached as e:
342            raise e
343
344    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
345        """
346        Parse a string or integer token expiration date into a datetime object
347
348        :return: expiration datetime
349        """
350        if self.token_expiry_is_time_of_expiration:
351            if not self.token_expiry_date_format:
352                raise ValueError(
353                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
354                )
355            try:
356                return ab_datetime_parse(str(value))
357            except ValueError as e:
358                raise ValueError(f"Invalid token expiry date format: {e}")
359        else:
360            try:
361                # Only accept numeric values (as int/float/string) when no format specified
362                seconds = int(float(str(value)))
363                return ab_datetime_now() + timedelta(seconds=seconds)
364            except (ValueError, TypeError):
365                raise ValueError(
366                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
367                )
368
369    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
370        """
371        Extracts the access token from the given response data.
372
373        Args:
374            response_data (Mapping[str, Any]): The response data from which to extract the access token.
375
376        Returns:
377            str: The extracted access token.
378        """
379        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
380
381    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
382        """
383        Extracts the refresh token from the given response data.
384
385        Args:
386            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
387
388        Returns:
389            str: The extracted refresh token.
390        """
391        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
392
393    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
394        """
395        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
396
397        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
398
399        Args:
400            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
401
402        Returns:
403            The extracted token_expiry_date or None if not found.
404        """
405        expires_in = self._find_and_get_value_from_response(
406            response_data, self.get_expires_in_name()
407        )
408        if expires_in is not None:
409            return self._parse_token_expiration_date(expires_in)
410
411        # expires_in is None
412        existing_expiry_date = self.get_token_expiry_date()
413        if existing_expiry_date and not self.token_has_expired():
414            return existing_expiry_date
415
416        return self._default_token_expiry_date()
417
418    def _find_and_get_value_from_response(
419        self,
420        response_data: Mapping[str, Any],
421        key_name: str,
422        max_depth: int = 5,
423        current_depth: int = 0,
424    ) -> Any:
425        """
426        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
427
428        Args:
429            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
430            key_name (str): The key to search for in the response data.
431            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
432            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
433
434        Returns:
435            Any: The value associated with the specified key if found, otherwise None.
436
437        Raises:
438            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
439        """
440        if current_depth > max_depth:
441            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
442            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
443            raise ResponseKeysMaxRecurtionReached(
444                internal_message=message, message=message, failure_type=FailureType.config_error
445            )
446
447        if isinstance(response_data, dict):
448            # get from the root level
449            if key_name in response_data:
450                return response_data[key_name]
451
452            # get from the nested object
453            for _, value in response_data.items():
454                result = self._find_and_get_value_from_response(
455                    value, key_name, max_depth, current_depth + 1
456                )
457                if result is not None:
458                    return result
459
460        # get from the nested array object
461        elif isinstance(response_data, list):
462            for item in response_data:
463                result = self._find_and_get_value_from_response(
464                    item, key_name, max_depth, current_depth + 1
465                )
466                if result is not None:
467                    return result
468
469        return None
470
471    @property
472    def _message_repository(self) -> Optional[MessageRepository]:
473        """
474        The implementation can define a message_repository if it wants debugging logs for HTTP requests
475        """
476        return _NOOP_MESSAGE_REPOSITORY
477
478    def _log_response(self, response: requests.Response) -> None:
479        """
480        Logs the HTTP response using the message repository if it is available.
481
482        Args:
483            response (requests.Response): The HTTP response to log.
484        """
485        if self._message_repository:
486            self._message_repository.log_message(
487                Level.DEBUG,
488                lambda: format_http_message(
489                    response,
490                    "Refresh token",
491                    "Obtains access token",
492                    self._NO_STREAM_NAME,
493                    is_auxiliary=True,
494                    type="AUTH",
495                ),
496            )
497
498    # ----------------
499    # ABSTR METHODS
500    # ----------------
501
502    @abstractmethod
503    def get_token_refresh_endpoint(self) -> Optional[str]:
504        """Returns the endpoint to refresh the access token"""
505
506    @abstractmethod
507    def get_client_id_name(self) -> str:
508        """The client id name to authenticate"""
509
510    @abstractmethod
511    def get_client_id(self) -> str:
512        """The client id to authenticate"""
513
514    @abstractmethod
515    def get_client_secret_name(self) -> str:
516        """The client secret name to authenticate"""
517
518    @abstractmethod
519    def get_client_secret(self) -> str:
520        """The client secret to authenticate"""
521
522    @abstractmethod
523    def get_refresh_token_name(self) -> str:
524        """The refresh token name to authenticate"""
525
526    @abstractmethod
527    def get_refresh_token(self) -> Optional[str]:
528        """The token used to refresh the access token when it expires"""
529
530    @abstractmethod
531    def get_scopes(self) -> List[str]:
532        """List of requested scopes"""
533
534    @abstractmethod
535    def get_token_expiry_date(self) -> AirbyteDateTime:
536        """Expiration date of the access token"""
537
538    @abstractmethod
539    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
540        """Setter for access token expiration date"""
541
542    @abstractmethod
543    def get_access_token_name(self) -> str:
544        """Field to extract access token from in the response"""
545
546    @abstractmethod
547    def get_expires_in_name(self) -> str:
548        """Returns the expires_in field name"""
549
550    @abstractmethod
551    def get_refresh_request_body(self) -> Mapping[str, Any]:
552        """Returns the request body to set on the refresh request"""
553
554    @abstractmethod
555    def get_refresh_request_headers(self) -> Mapping[str, Any]:
556        """Returns the request headers to set on the refresh request"""
557
558    @abstractmethod
559    def get_grant_type(self) -> str:
560        """Returns grant_type specified for requesting access_token"""
561
562    @abstractmethod
563    def get_grant_type_name(self) -> str:
564        """Returns grant_type specified name for requesting access_token"""
565
566    @property
567    @abstractmethod
568    def access_token(self) -> str:
569        """Returns the access token"""
570
571    @access_token.setter
572    @abstractmethod
573    def access_token(self, value: str) -> str:
574        """Setter for the access token"""
logger = <Logger airbyte (INFO)>
class ResponseKeysMaxRecurtionReached(airbyte_cdk.utils.traced_exception.AirbyteTracedException):
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """Signals that a recursive key lookup in a token-refresh response went too deep.

    The authenticator searches the (possibly nested) refresh response for a target key,
    for example while handling the request in `_make_handled_request`. This exception is
    raised once that search exceeds its maximum nesting depth.
    """

Raised when the maximum recursion depth is reached while trying to find and get the target key during _make_handled_request.

class AbstractOauth2Authenticator(requests.auth.AuthBase):
 37class AbstractOauth2Authenticator(AuthBase):
 38    """
 39    Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
 40    is designed to generically perform the refresh flow without regard to how config fields are get/set by
 41    delegating that behavior to the classes implementing the interface.
 42    """
 43
 44    _NO_STREAM_NAME = None
 45
 46    # Class-level lock to prevent concurrent token refresh across multiple authenticator instances.
 47    # This is necessary because multiple streams may share the same OAuth credentials (refresh token)
 48    # through the connector config. Without this lock, concurrent refresh attempts can cause race
 49    # conditions where one stream successfully refreshes the token while others fail because the
 50    # refresh token has been invalidated (especially for single-use refresh tokens).
 51    _token_refresh_lock: threading.Lock = threading.Lock()
 52
 53    def __init__(
 54        self,
 55        refresh_token_error_status_codes: Tuple[int, ...] = (),
 56        refresh_token_error_key: str = "",
 57        refresh_token_error_values: Tuple[str, ...] = (),
 58    ) -> None:
 59        """
 60        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 61        then http errors with such params will be wrapped in AirbyteTracedException.
 62        """
 63        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 64        self._refresh_token_error_key = refresh_token_error_key
 65        self._refresh_token_error_values = refresh_token_error_values
 66
 67    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 68        """Attach the HTTP headers required to authenticate on the HTTP request"""
 69        request.headers.update(self.get_auth_header())
 70        return request
 71
 72    @property
 73    def _is_access_token_flow(self) -> bool:
 74        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 75
 76    @property
 77    def token_expiry_is_time_of_expiration(self) -> bool:
 78        """
 79        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
 80        """
 81
 82        return False
 83
 84    @property
 85    def token_expiry_date_format(self) -> Optional[str]:
 86        """
 87        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
 88        """
 89
 90        return None
 91
 92    def get_auth_header(self) -> Mapping[str, Any]:
 93        """HTTP header to set on the requests"""
 94        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 95        return {"Authorization": f"Bearer {token}"}
 96
 97    def get_access_token(self) -> str:
 98        """
 99        Returns the access token.
100
101        This method uses double-checked locking to ensure thread-safe token refresh.
102        When multiple threads (streams) detect an expired token simultaneously, only one
103        will perform the refresh while others wait. After acquiring the lock, the token
104        expiry is re-checked to avoid redundant refresh attempts.
105        """
106        if self.token_has_expired():
107            with self._token_refresh_lock:
108                # Double-check after acquiring lock - another thread may have already refreshed
109                if self.token_has_expired():
110                    self.refresh_and_set_access_token()
111
112        return self.access_token
113
114    def refresh_and_set_access_token(self) -> None:
115        """Force refresh the access token and update internal state.
116
117        This method refreshes the access token regardless of whether it has expired,
118        and updates the internal token and expiry date. Subclasses may override this
119        to handle additional state updates (e.g., persisting new refresh tokens).
120        """
121        token, expires_in = self.refresh_access_token()
122        self.access_token = token
123        self.set_token_expiry_date(expires_in)
124
125    def token_has_expired(self) -> bool:
126        """Returns True if the token is expired"""
127        return ab_datetime_now() > self.get_token_expiry_date()
128
129    def build_refresh_request_body(self) -> Mapping[str, Any]:
130        """
131        Returns the request body to set on the refresh request.
132
133        Override to define additional parameters.
134
135        Client credentials (client_id and client_secret) are excluded from the body when
136        refresh_request_headers contains an Authorization header (e.g., Basic auth).
137        This is required by OAuth providers like Gong that expect credentials ONLY in the
138        Authorization header and reject requests that include them in both places.
139        """
140        # Check if credentials are being sent via Authorization header
141        headers = self.get_refresh_request_headers()
142        credentials_in_header = headers and "Authorization" in headers
143
144        # Only include client credentials in body if not already in header
145        include_client_credentials = not credentials_in_header
146
147        payload: MutableMapping[str, Any] = {
148            self.get_grant_type_name(): self.get_grant_type(),
149        }
150
151        # Only include client credentials in body if configured to do so and not in header
152        if include_client_credentials:
153            payload[self.get_client_id_name()] = self.get_client_id()
154            payload[self.get_client_secret_name()] = self.get_client_secret()
155
156        payload[self.get_refresh_token_name()] = self.get_refresh_token()
157
158        if self.get_scopes():
159            payload["scopes"] = self.get_scopes()
160
161        if self.get_refresh_request_body():
162            for key, val in self.get_refresh_request_body().items():
163                # We defer to existing oauth constructs over custom configured fields
164                if key not in payload:
165                    payload[key] = val
166
167        return payload
168
169    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
170        """
171        Returns the request headers to set on the refresh request
172
173        """
174        headers = self.get_refresh_request_headers()
175        return headers if headers else None
176
177    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
178        """
179        Returns the refresh token and its expiration datetime
180
181        :return: a tuple of (access_token, token_lifespan)
182        """
183        try:
184            response_json = self._make_handled_request()
185        except (
186            requests.exceptions.ConnectionError,
187            requests.exceptions.ConnectTimeout,
188            requests.exceptions.ReadTimeout,
189        ) as e:
190            raise AirbyteTracedException(
191                message="OAuth access token refresh request failed due to a network error.",
192                internal_message=f"Network error during OAuth token refresh after retries were exhausted: {e}",
193                failure_type=FailureType.transient_error,
194            ) from e
195        self._ensure_access_token_in_response(response_json)
196
197        return (
198            self._extract_access_token(response_json),
199            self._extract_token_expiry_date(response_json),
200        )
201
202    # ----------------
203    # PRIVATE METHODS
204    # ----------------
205
206    def _default_token_expiry_date(self) -> AirbyteDateTime:
207        """
208        Returns the default token expiry date
209        """
210        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
211        default_token_expiry_duration_hours = 1  # 1 hour
212        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
213
214    def _wrap_refresh_token_exception(
215        self, exception: requests.exceptions.RequestException
216    ) -> bool:
217        """
218        Wraps and handles exceptions that occur during the refresh token process.
219
220        This method checks if the provided exception is related to a refresh token error
221        by examining the response status code and specific error content.
222
223        Args:
224            exception (requests.exceptions.RequestException): The exception raised during the request.
225
226        Returns:
227            bool: True if the exception is related to a refresh token error, False otherwise.
228        """
229        try:
230            if exception.response is not None:
231                exception_content = exception.response.json()
232            else:
233                return False
234        except JSONDecodeError:
235            return False
236        return (
237            exception.response.status_code in self._refresh_token_error_status_codes
238            and exception_content.get(self._refresh_token_error_key)
239            in self._refresh_token_error_values
240        )
241
242    @backoff.on_exception(
243        backoff.expo,
244        (
245            DefaultBackoffException,
246            requests.exceptions.ConnectionError,
247            requests.exceptions.ConnectTimeout,
248            requests.exceptions.ReadTimeout,
249        ),
250        on_backoff=lambda details: logger.info(
251            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
252        ),
253        max_time=300,
254    )
255    def _make_handled_request(self) -> Any:
256        """
257        Makes a handled HTTP request to refresh an OAuth token.
258
259        This method sends a POST request to the token refresh endpoint with the necessary
260        headers and body to obtain a new access token. It handles various exceptions that
261        may occur during the request and logs the response for troubleshooting purposes.
262
263        Returns:
264            Mapping[str, Any]: The JSON response from the token refresh endpoint.
265
266        Raises:
267            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
268                                     or any 5xx server error.
269            AirbyteTracedException: If the refresh token is invalid or expired, prompting
270                                    re-authentication.
271            Exception: For any other exceptions that occur during the request.
272        """
273        try:
274            response = requests.request(
275                method="POST",
276                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
277                data=self.build_refresh_request_body(),
278                headers=self.build_refresh_request_headers(),
279            )
280
281            if not response.ok:
282                # log the response even if the request failed for troubleshooting purposes
283                self._log_response(response)
284                response.raise_for_status()
285
286            response_json = response.json()
287
288            try:
289                # extract the access token and add to secrets to avoid logging the raw value
290                access_key = self._extract_access_token(response_json)
291                if access_key:
292                    add_to_secrets(access_key)
293            except ResponseKeysMaxRecurtionReached as e:
294                # could not find the access token in the response, so do nothing
295                pass
296
297            self._log_response(response)
298
299            return response_json
300        except requests.exceptions.RequestException as e:
301            if e.response is not None:
302                if e.response.status_code == 429 or e.response.status_code >= 500:
303                    raise DefaultBackoffException(
304                        request=e.response.request,
305                        response=e.response,
306                        failure_type=FailureType.transient_error,
307                    )
308            if self._wrap_refresh_token_exception(e):
309                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
310                raise AirbyteTracedException(
311                    internal_message=message, message=message, failure_type=FailureType.config_error
312                )
313            raise
314        except Exception as e:
315            raise AirbyteTracedException(
316                message="OAuth access token refresh request failed.",
317                internal_message=f"Unexpected error during OAuth token refresh: {e}",
318                failure_type=FailureType.system_error,
319            ) from e
320
321    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
322        """
323        Ensures that the access token is present in the response data.
324
325        This method attempts to extract the access token from the provided response data.
326        If the access token is not found, it raises an exception indicating that the token
327        refresh API response was missing the access token.
328
329        Args:
330            response_data (Mapping[str, Any]): The response data from which to extract the access token.
331
332        Raises:
333            Exception: If the access token is not found in the response data.
334            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
335        """
336        try:
337            access_key = self._extract_access_token(response_data)
338            if not access_key:
339                raise Exception(
340                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
341                )
342        except ResponseKeysMaxRecurtionReached as e:
343            raise e
344
345    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
346        """
347        Parse a string or integer token expiration date into a datetime object
348
349        :return: expiration datetime
350        """
351        if self.token_expiry_is_time_of_expiration:
352            if not self.token_expiry_date_format:
353                raise ValueError(
354                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
355                )
356            try:
357                return ab_datetime_parse(str(value))
358            except ValueError as e:
359                raise ValueError(f"Invalid token expiry date format: {e}")
360        else:
361            try:
362                # Only accept numeric values (as int/float/string) when no format specified
363                seconds = int(float(str(value)))
364                return ab_datetime_now() + timedelta(seconds=seconds)
365            except (ValueError, TypeError):
366                raise ValueError(
367                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
368                )
369
370    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
371        """
372        Extracts the access token from the given response data.
373
374        Args:
375            response_data (Mapping[str, Any]): The response data from which to extract the access token.
376
377        Returns:
378            str: The extracted access token.
379        """
380        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
381
382    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
383        """
384        Extracts the refresh token from the given response data.
385
386        Args:
387            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
388
389        Returns:
390            str: The extracted refresh token.
391        """
392        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
393
394    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
395        """
396        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
397
398        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
399
400        Args:
401            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
402
403        Returns:
404            The extracted token_expiry_date or None if not found.
405        """
406        expires_in = self._find_and_get_value_from_response(
407            response_data, self.get_expires_in_name()
408        )
409        if expires_in is not None:
410            return self._parse_token_expiration_date(expires_in)
411
412        # expires_in is None
413        existing_expiry_date = self.get_token_expiry_date()
414        if existing_expiry_date and not self.token_has_expired():
415            return existing_expiry_date
416
417        return self._default_token_expiry_date()
418
419    def _find_and_get_value_from_response(
420        self,
421        response_data: Mapping[str, Any],
422        key_name: str,
423        max_depth: int = 5,
424        current_depth: int = 0,
425    ) -> Any:
426        """
427        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
428
429        Args:
430            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
431            key_name (str): The key to search for in the response data.
432            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
433            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
434
435        Returns:
436            Any: The value associated with the specified key if found, otherwise None.
437
438        Raises:
439            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
440        """
441        if current_depth > max_depth:
442            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
443            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
444            raise ResponseKeysMaxRecurtionReached(
445                internal_message=message, message=message, failure_type=FailureType.config_error
446            )
447
448        if isinstance(response_data, dict):
449            # get from the root level
450            if key_name in response_data:
451                return response_data[key_name]
452
453            # get from the nested object
454            for _, value in response_data.items():
455                result = self._find_and_get_value_from_response(
456                    value, key_name, max_depth, current_depth + 1
457                )
458                if result is not None:
459                    return result
460
461        # get from the nested array object
462        elif isinstance(response_data, list):
463            for item in response_data:
464                result = self._find_and_get_value_from_response(
465                    item, key_name, max_depth, current_depth + 1
466                )
467                if result is not None:
468                    return result
469
470        return None
471
472    @property
473    def _message_repository(self) -> Optional[MessageRepository]:
474        """
475        The implementation can define a message_repository if it wants debugging logs for HTTP requests
476        """
477        return _NOOP_MESSAGE_REPOSITORY
478
479    def _log_response(self, response: requests.Response) -> None:
480        """
481        Logs the HTTP response using the message repository if it is available.
482
483        Args:
484            response (requests.Response): The HTTP response to log.
485        """
486        if self._message_repository:
487            self._message_repository.log_message(
488                Level.DEBUG,
489                lambda: format_http_message(
490                    response,
491                    "Refresh token",
492                    "Obtains access token",
493                    self._NO_STREAM_NAME,
494                    is_auxiliary=True,
495                    type="AUTH",
496                ),
497            )
498
499    # ----------------
500    # ABSTR METHODS
501    # ----------------
502
503    @abstractmethod
504    def get_token_refresh_endpoint(self) -> Optional[str]:
505        """Returns the endpoint to refresh the access token"""
506
507    @abstractmethod
508    def get_client_id_name(self) -> str:
509        """The client id name to authenticate"""
510
511    @abstractmethod
512    def get_client_id(self) -> str:
513        """The client id to authenticate"""
514
515    @abstractmethod
516    def get_client_secret_name(self) -> str:
517        """The client secret name to authenticate"""
518
519    @abstractmethod
520    def get_client_secret(self) -> str:
521        """The client secret to authenticate"""
522
523    @abstractmethod
524    def get_refresh_token_name(self) -> str:
525        """The refresh token name to authenticate"""
526
527    @abstractmethod
528    def get_refresh_token(self) -> Optional[str]:
529        """The token used to refresh the access token when it expires"""
530
531    @abstractmethod
532    def get_scopes(self) -> List[str]:
533        """List of requested scopes"""
534
535    @abstractmethod
536    def get_token_expiry_date(self) -> AirbyteDateTime:
537        """Expiration date of the access token"""
538
539    @abstractmethod
540    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
541        """Setter for access token expiration date"""
542
543    @abstractmethod
544    def get_access_token_name(self) -> str:
545        """Field to extract access token from in the response"""
546
547    @abstractmethod
548    def get_expires_in_name(self) -> str:
549        """Returns the expires_in field name"""
550
551    @abstractmethod
552    def get_refresh_request_body(self) -> Mapping[str, Any]:
553        """Returns the request body to set on the refresh request"""
554
555    @abstractmethod
556    def get_refresh_request_headers(self) -> Mapping[str, Any]:
557        """Returns the request headers to set on the refresh request"""
558
559    @abstractmethod
560    def get_grant_type(self) -> str:
561        """Returns grant_type specified for requesting access_token"""
562
563    @abstractmethod
564    def get_grant_type_name(self) -> str:
565        """Returns grant_type specified name for requesting access_token"""
566
567    @property
568    @abstractmethod
569    def access_token(self) -> str:
570        """Returns the access token"""
571
572    @access_token.setter
573    @abstractmethod
574    def access_token(self, value: str) -> str:
575        """Setter for the access token"""

Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator performs the refresh flow generically, without regard to how config fields are read or written, by delegating that behavior to the classes implementing the interface.

AbstractOauth2Authenticator( refresh_token_error_status_codes: Tuple[int, ...] = (), refresh_token_error_key: str = '', refresh_token_error_values: Tuple[str, ...] = ())
53    def __init__(
54        self,
55        refresh_token_error_status_codes: Tuple[int, ...] = (),
56        refresh_token_error_key: str = "",
57        refresh_token_error_values: Tuple[str, ...] = (),
58    ) -> None:
59        """
60        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
61        then http errors with such params will be wrapped in AirbyteTracedException.
62        """
63        self._refresh_token_error_status_codes = refresh_token_error_status_codes
64        self._refresh_token_error_key = refresh_token_error_key
65        self._refresh_token_error_values = refresh_token_error_values

If refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are all set, then HTTP errors matching these parameters will be wrapped in an AirbyteTracedException.

token_expiry_is_time_of_expiration: bool
76    @property
77    def token_expiry_is_time_of_expiration(self) -> bool:
78        """
79        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
80        """
81
82        return False

Indicates whether the token expiry value is the date until which the token will be valid, rather than the amount of time it will remain valid.

token_expiry_date_format: Optional[str]
84    @property
85    def token_expiry_date_format(self) -> Optional[str]:
86        """
87        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
88        """
89
90        return None

Format of the datetime; set it if expires_in is returned as the expiration datetime instead of the number of seconds until the token expires.

def get_auth_header(self) -> Mapping[str, Any]:
92    def get_auth_header(self) -> Mapping[str, Any]:
93        """HTTP header to set on the requests"""
94        token = self.access_token if self._is_access_token_flow else self.get_access_token()
95        return {"Authorization": f"Bearer {token}"}

HTTP header to set on the requests

def get_access_token(self) -> str:
 97    def get_access_token(self) -> str:
 98        """
 99        Returns the access token.
100
101        This method uses double-checked locking to ensure thread-safe token refresh.
102        When multiple threads (streams) detect an expired token simultaneously, only one
103        will perform the refresh while others wait. After acquiring the lock, the token
104        expiry is re-checked to avoid redundant refresh attempts.
105        """
106        if self.token_has_expired():
107            with self._token_refresh_lock:
108                # Double-check after acquiring lock - another thread may have already refreshed
109                if self.token_has_expired():
110                    self.refresh_and_set_access_token()
111
112        return self.access_token

Returns the access token.

This method uses double-checked locking to ensure thread-safe token refresh. When multiple threads (streams) detect an expired token simultaneously, only one will perform the refresh while others wait. After acquiring the lock, the token expiry is re-checked to avoid redundant refresh attempts.

def refresh_and_set_access_token(self) -> None:
114    def refresh_and_set_access_token(self) -> None:
115        """Force refresh the access token and update internal state.
116
117        This method refreshes the access token regardless of whether it has expired,
118        and updates the internal token and expiry date. Subclasses may override this
119        to handle additional state updates (e.g., persisting new refresh tokens).
120        """
121        token, expires_in = self.refresh_access_token()
122        self.access_token = token
123        self.set_token_expiry_date(expires_in)

Force refresh the access token and update internal state.

This method refreshes the access token regardless of whether it has expired, and updates the internal token and expiry date. Subclasses may override this to handle additional state updates (e.g., persisting new refresh tokens).

def token_has_expired(self) -> bool:
125    def token_has_expired(self) -> bool:
126        """Returns True if the token is expired"""
127        return ab_datetime_now() > self.get_token_expiry_date()

Returns True if the token is expired

def build_refresh_request_body(self) -> Mapping[str, Any]:
129    def build_refresh_request_body(self) -> Mapping[str, Any]:
130        """
131        Returns the request body to set on the refresh request.
132
133        Override to define additional parameters.
134
135        Client credentials (client_id and client_secret) are excluded from the body when
136        refresh_request_headers contains an Authorization header (e.g., Basic auth).
137        This is required by OAuth providers like Gong that expect credentials ONLY in the
138        Authorization header and reject requests that include them in both places.
139        """
140        # Check if credentials are being sent via Authorization header
141        headers = self.get_refresh_request_headers()
142        credentials_in_header = headers and "Authorization" in headers
143
144        # Only include client credentials in body if not already in header
145        include_client_credentials = not credentials_in_header
146
147        payload: MutableMapping[str, Any] = {
148            self.get_grant_type_name(): self.get_grant_type(),
149        }
150
151        # Only include client credentials in body if configured to do so and not in header
152        if include_client_credentials:
153            payload[self.get_client_id_name()] = self.get_client_id()
154            payload[self.get_client_secret_name()] = self.get_client_secret()
155
156        payload[self.get_refresh_token_name()] = self.get_refresh_token()
157
158        if self.get_scopes():
159            payload["scopes"] = self.get_scopes()
160
161        if self.get_refresh_request_body():
162            for key, val in self.get_refresh_request_body().items():
163                # We defer to existing oauth constructs over custom configured fields
164                if key not in payload:
165                    payload[key] = val
166
167        return payload

Returns the request body to set on the refresh request.

Override to define additional parameters.

Client credentials (client_id and client_secret) are excluded from the body when refresh_request_headers contains an Authorization header (e.g., Basic auth). This is required by OAuth providers like Gong that expect credentials ONLY in the Authorization header and reject requests that include them in both places.

def build_refresh_request_headers(self) -> Optional[Mapping[str, Any]]:
169    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
170        """
171        Returns the request headers to set on the refresh request
172
173        """
174        headers = self.get_refresh_request_headers()
175        return headers if headers else None

Returns the request headers to set on the refresh request

def refresh_access_token(self) -> Tuple[str, airbyte_cdk.utils.datetime_helpers.AirbyteDateTime]:
177    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
178        """
179        Returns the refresh token and its expiration datetime
180
181        :return: a tuple of (access_token, token_lifespan)
182        """
183        try:
184            response_json = self._make_handled_request()
185        except (
186            requests.exceptions.ConnectionError,
187            requests.exceptions.ConnectTimeout,
188            requests.exceptions.ReadTimeout,
189        ) as e:
190            raise AirbyteTracedException(
191                message="OAuth access token refresh request failed due to a network error.",
192                internal_message=f"Network error during OAuth token refresh after retries were exhausted: {e}",
193                failure_type=FailureType.transient_error,
194            ) from e
195        self._ensure_access_token_in_response(response_json)
196
197        return (
198            self._extract_access_token(response_json),
199            self._extract_token_expiry_date(response_json),
200        )

Refreshes the access token and returns it together with its expiration datetime.

Returns

a tuple of (access_token, token_expiry_date)

@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
503    @abstractmethod
504    def get_token_refresh_endpoint(self) -> Optional[str]:
505        """Returns the endpoint used to refresh the access token, or None"""

Returns the endpoint used to refresh the access token, or None

@abstractmethod
def get_client_id_name(self) -> str:
507    @abstractmethod
508    def get_client_id_name(self) -> str:
509        """Returns the name of the client id field used to authenticate"""

Returns the name of the client id field used to authenticate

@abstractmethod
def get_client_id(self) -> str:
511    @abstractmethod
512    def get_client_id(self) -> str:
513        """Returns the client id used to authenticate"""

Returns the client id used to authenticate

@abstractmethod
def get_client_secret_name(self) -> str:
515    @abstractmethod
516    def get_client_secret_name(self) -> str:
517        """Returns the name of the client secret field used to authenticate"""

Returns the name of the client secret field used to authenticate

@abstractmethod
def get_client_secret(self) -> str:
519    @abstractmethod
520    def get_client_secret(self) -> str:
521        """Returns the client secret used to authenticate"""

Returns the client secret used to authenticate

@abstractmethod
def get_refresh_token_name(self) -> str:
523    @abstractmethod
524    def get_refresh_token_name(self) -> str:
525        """Returns the name of the refresh token field used to authenticate"""

Returns the name of the refresh token field used to authenticate

@abstractmethod
def get_refresh_token(self) -> Optional[str]:
527    @abstractmethod
528    def get_refresh_token(self) -> Optional[str]:
529        """Returns the token used to refresh the access token when it expires, or None"""

Returns the token used to refresh the access token when it expires, or None

@abstractmethod
def get_scopes(self) -> List[str]:
531    @abstractmethod
532    def get_scopes(self) -> List[str]:
533        """Returns the list of requested scopes"""

Returns the list of requested scopes

@abstractmethod
def get_token_expiry_date(self) -> airbyte_cdk.utils.datetime_helpers.AirbyteDateTime:
535    @abstractmethod
536    def get_token_expiry_date(self) -> AirbyteDateTime:
537        """Returns the expiration date of the access token"""

Returns the expiration date of the access token

@abstractmethod
def set_token_expiry_date(self, value: airbyte_cdk.utils.datetime_helpers.AirbyteDateTime) -> None:
539    @abstractmethod
540    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
541        """Sets the expiration date of the access token"""

Sets the expiration date of the access token

@abstractmethod
def get_access_token_name(self) -> str:
543    @abstractmethod
544    def get_access_token_name(self) -> str:
545        """Returns the name of the response field the access token is extracted from"""

Returns the name of the response field the access token is extracted from

@abstractmethod
def get_expires_in_name(self) -> str:
547    @abstractmethod
548    def get_expires_in_name(self) -> str:
549        """Returns the name of the expires_in field"""

Returns the name of the expires_in field

@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
551    @abstractmethod
552    def get_refresh_request_body(self) -> Mapping[str, Any]:
553        """Returns the request body to send with the token refresh request"""

Returns the request body to send with the token refresh request

@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
555    @abstractmethod
556    def get_refresh_request_headers(self) -> Mapping[str, Any]:
557        """Returns the request headers to send with the token refresh request"""

Returns the request headers to send with the token refresh request

@abstractmethod
def get_grant_type(self) -> str:
559    @abstractmethod
560    def get_grant_type(self) -> str:
561        """Returns the grant_type value used when requesting the access token"""

Returns the grant_type value used when requesting the access token

@abstractmethod
def get_grant_type_name(self) -> str:
563    @abstractmethod
564    def get_grant_type_name(self) -> str:
565        """Returns the name of the grant_type field used when requesting the access token"""

Returns the name of the grant_type field used when requesting the access token

access_token: str
567    @property
568    @abstractmethod
569    def access_token(self) -> str:
570        """Returns the access token; must be implemented by subclasses"""

Returns the access token; must be implemented by subclasses