airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth

  1#
  2# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
  3#
  4
  5import logging
  6import threading
  7from abc import abstractmethod
  8from datetime import timedelta
  9from json import JSONDecodeError
 10from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union
 11
 12import backoff
 13import requests
 14from requests.auth import AuthBase
 15
 16from airbyte_cdk.models import FailureType, Level
 17from airbyte_cdk.sources.http_logger import format_http_message
 18from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
 19from airbyte_cdk.utils import AirbyteTracedException
 20from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
 21from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse
 22
 23from ..exceptions import DefaultBackoffException
 24
# Module logger; obtained under the "airbyte" logger name.
logger = logging.getLogger("airbyte")
# Single no-op repository instance returned by the default `_message_repository` property,
# created once at import time rather than on every property access.
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()
 27
 28
 29class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
 30    """
 31    Raised when the max level of recursion is reached, when trying to
 32    find-and-get the target key, during the `_make_handled_request`
 33    """
 34
 35
 36class AbstractOauth2Authenticator(AuthBase):
 37    """
 38    Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
 39    is designed to generically perform the refresh flow without regard to how config fields are get/set by
 40    delegating that behavior to the classes implementing the interface.
 41    """
 42
 43    _NO_STREAM_NAME = None
 44
 45    # Class-level lock to prevent concurrent token refresh across multiple authenticator instances.
 46    # This is necessary because multiple streams may share the same OAuth credentials (refresh token)
 47    # through the connector config. Without this lock, concurrent refresh attempts can cause race
 48    # conditions where one stream successfully refreshes the token while others fail because the
 49    # refresh token has been invalidated (especially for single-use refresh tokens).
 50    _token_refresh_lock: threading.Lock = threading.Lock()
 51
 52    def __init__(
 53        self,
 54        refresh_token_error_status_codes: Tuple[int, ...] = (),
 55        refresh_token_error_key: str = "",
 56        refresh_token_error_values: Tuple[str, ...] = (),
 57    ) -> None:
 58        """
 59        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 60        then http errors with such params will be wrapped in AirbyteTracedException.
 61        """
 62        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 63        self._refresh_token_error_key = refresh_token_error_key
 64        self._refresh_token_error_values = refresh_token_error_values
 65
 66    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 67        """Attach the HTTP headers required to authenticate on the HTTP request"""
 68        request.headers.update(self.get_auth_header())
 69        return request
 70
 71    @property
 72    def _is_access_token_flow(self) -> bool:
 73        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 74
 75    @property
 76    def token_expiry_is_time_of_expiration(self) -> bool:
 77        """
 78        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
 79        """
 80
 81        return False
 82
 83    @property
 84    def token_expiry_date_format(self) -> Optional[str]:
 85        """
 86        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
 87        """
 88
 89        return None
 90
 91    def get_auth_header(self) -> Mapping[str, Any]:
 92        """HTTP header to set on the requests"""
 93        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 94        return {"Authorization": f"Bearer {token}"}
 95
 96    def get_access_token(self) -> str:
 97        """
 98        Returns the access token.
 99
100        This method uses double-checked locking to ensure thread-safe token refresh.
101        When multiple threads (streams) detect an expired token simultaneously, only one
102        will perform the refresh while others wait. After acquiring the lock, the token
103        expiry is re-checked to avoid redundant refresh attempts.
104        """
105        if self.token_has_expired():
106            with self._token_refresh_lock:
107                # Double-check after acquiring lock - another thread may have already refreshed
108                if self.token_has_expired():
109                    self.refresh_and_set_access_token()
110
111        return self.access_token
112
113    def refresh_and_set_access_token(self) -> None:
114        """Force refresh the access token and update internal state.
115
116        This method refreshes the access token regardless of whether it has expired,
117        and updates the internal token and expiry date. Subclasses may override this
118        to handle additional state updates (e.g., persisting new refresh tokens).
119        """
120        token, expires_in = self.refresh_access_token()
121        self.access_token = token
122        self.set_token_expiry_date(expires_in)
123
124    def token_has_expired(self) -> bool:
125        """Returns True if the token is expired"""
126        return ab_datetime_now() > self.get_token_expiry_date()
127
128    def build_refresh_request_body(self) -> Mapping[str, Any]:
129        """
130        Returns the request body to set on the refresh request.
131
132        Override to define additional parameters.
133
134        Client credentials (client_id and client_secret) are excluded from the body when
135        refresh_request_headers contains an Authorization header (e.g., Basic auth).
136        This is required by OAuth providers like Gong that expect credentials ONLY in the
137        Authorization header and reject requests that include them in both places.
138        """
139        # Check if credentials are being sent via Authorization header
140        headers = self.get_refresh_request_headers()
141        credentials_in_header = headers and "Authorization" in headers
142
143        # Only include client credentials in body if not already in header
144        include_client_credentials = not credentials_in_header
145
146        payload: MutableMapping[str, Any] = {
147            self.get_grant_type_name(): self.get_grant_type(),
148        }
149
150        # Only include client credentials in body if configured to do so and not in header
151        if include_client_credentials:
152            payload[self.get_client_id_name()] = self.get_client_id()
153            payload[self.get_client_secret_name()] = self.get_client_secret()
154
155        payload[self.get_refresh_token_name()] = self.get_refresh_token()
156
157        if self.get_scopes():
158            payload["scopes"] = self.get_scopes()
159
160        if self.get_refresh_request_body():
161            for key, val in self.get_refresh_request_body().items():
162                # We defer to existing oauth constructs over custom configured fields
163                if key not in payload:
164                    payload[key] = val
165
166        return payload
167
168    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
169        """
170        Returns the request headers to set on the refresh request
171
172        """
173        headers = self.get_refresh_request_headers()
174        return headers if headers else None
175
176    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
177        """
178        Returns the refresh token and its expiration datetime
179
180        :return: a tuple of (access_token, token_lifespan)
181        """
182        response_json = self._make_handled_request()
183        self._ensure_access_token_in_response(response_json)
184
185        return (
186            self._extract_access_token(response_json),
187            self._extract_token_expiry_date(response_json),
188        )
189
190    # ----------------
191    # PRIVATE METHODS
192    # ----------------
193
194    def _default_token_expiry_date(self) -> AirbyteDateTime:
195        """
196        Returns the default token expiry date
197        """
198        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
199        default_token_expiry_duration_hours = 1  # 1 hour
200        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
201
202    def _wrap_refresh_token_exception(
203        self, exception: requests.exceptions.RequestException
204    ) -> bool:
205        """
206        Wraps and handles exceptions that occur during the refresh token process.
207
208        This method checks if the provided exception is related to a refresh token error
209        by examining the response status code and specific error content.
210
211        Args:
212            exception (requests.exceptions.RequestException): The exception raised during the request.
213
214        Returns:
215            bool: True if the exception is related to a refresh token error, False otherwise.
216        """
217        try:
218            if exception.response is not None:
219                exception_content = exception.response.json()
220            else:
221                return False
222        except JSONDecodeError:
223            return False
224        return (
225            exception.response.status_code in self._refresh_token_error_status_codes
226            and exception_content.get(self._refresh_token_error_key)
227            in self._refresh_token_error_values
228        )
229
230    @backoff.on_exception(
231        backoff.expo,
232        DefaultBackoffException,
233        on_backoff=lambda details: logger.info(
234            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
235        ),
236        max_time=300,
237    )
238    def _make_handled_request(self) -> Any:
239        """
240        Makes a handled HTTP request to refresh an OAuth token.
241
242        This method sends a POST request to the token refresh endpoint with the necessary
243        headers and body to obtain a new access token. It handles various exceptions that
244        may occur during the request and logs the response for troubleshooting purposes.
245
246        Returns:
247            Mapping[str, Any]: The JSON response from the token refresh endpoint.
248
249        Raises:
250            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
251                                     or any 5xx server error.
252            AirbyteTracedException: If the refresh token is invalid or expired, prompting
253                                    re-authentication.
254            Exception: For any other exceptions that occur during the request.
255        """
256        try:
257            response = requests.request(
258                method="POST",
259                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
260                data=self.build_refresh_request_body(),
261                headers=self.build_refresh_request_headers(),
262            )
263
264            if not response.ok:
265                # log the response even if the request failed for troubleshooting purposes
266                self._log_response(response)
267                response.raise_for_status()
268
269            response_json = response.json()
270
271            try:
272                # extract the access token and add to secrets to avoid logging the raw value
273                access_key = self._extract_access_token(response_json)
274                if access_key:
275                    add_to_secrets(access_key)
276            except ResponseKeysMaxRecurtionReached as e:
277                # could not find the access token in the response, so do nothing
278                pass
279
280            self._log_response(response)
281
282            return response_json
283        except requests.exceptions.RequestException as e:
284            if e.response is not None:
285                if e.response.status_code == 429 or e.response.status_code >= 500:
286                    raise DefaultBackoffException(
287                        request=e.response.request,
288                        response=e.response,
289                        failure_type=FailureType.transient_error,
290                    )
291            if self._wrap_refresh_token_exception(e):
292                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
293                raise AirbyteTracedException(
294                    internal_message=message, message=message, failure_type=FailureType.config_error
295                )
296            raise
297        except Exception as e:
298            raise Exception(f"Error while refreshing access token: {e}") from e
299
300    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
301        """
302        Ensures that the access token is present in the response data.
303
304        This method attempts to extract the access token from the provided response data.
305        If the access token is not found, it raises an exception indicating that the token
306        refresh API response was missing the access token.
307
308        Args:
309            response_data (Mapping[str, Any]): The response data from which to extract the access token.
310
311        Raises:
312            Exception: If the access token is not found in the response data.
313            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
314        """
315        try:
316            access_key = self._extract_access_token(response_data)
317            if not access_key:
318                raise Exception(
319                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
320                )
321        except ResponseKeysMaxRecurtionReached as e:
322            raise e
323
324    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
325        """
326        Parse a string or integer token expiration date into a datetime object
327
328        :return: expiration datetime
329        """
330        if self.token_expiry_is_time_of_expiration:
331            if not self.token_expiry_date_format:
332                raise ValueError(
333                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
334                )
335            try:
336                return ab_datetime_parse(str(value))
337            except ValueError as e:
338                raise ValueError(f"Invalid token expiry date format: {e}")
339        else:
340            try:
341                # Only accept numeric values (as int/float/string) when no format specified
342                seconds = int(float(str(value)))
343                return ab_datetime_now() + timedelta(seconds=seconds)
344            except (ValueError, TypeError):
345                raise ValueError(
346                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
347                )
348
349    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
350        """
351        Extracts the access token from the given response data.
352
353        Args:
354            response_data (Mapping[str, Any]): The response data from which to extract the access token.
355
356        Returns:
357            str: The extracted access token.
358        """
359        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
360
361    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
362        """
363        Extracts the refresh token from the given response data.
364
365        Args:
366            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
367
368        Returns:
369            str: The extracted refresh token.
370        """
371        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
372
373    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
374        """
375        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
376
377        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
378
379        Args:
380            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
381
382        Returns:
383            The extracted token_expiry_date or None if not found.
384        """
385        expires_in = self._find_and_get_value_from_response(
386            response_data, self.get_expires_in_name()
387        )
388        if expires_in is not None:
389            return self._parse_token_expiration_date(expires_in)
390
391        # expires_in is None
392        existing_expiry_date = self.get_token_expiry_date()
393        if existing_expiry_date and not self.token_has_expired():
394            return existing_expiry_date
395
396        return self._default_token_expiry_date()
397
398    def _find_and_get_value_from_response(
399        self,
400        response_data: Mapping[str, Any],
401        key_name: str,
402        max_depth: int = 5,
403        current_depth: int = 0,
404    ) -> Any:
405        """
406        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
407
408        Args:
409            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
410            key_name (str): The key to search for in the response data.
411            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
412            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
413
414        Returns:
415            Any: The value associated with the specified key if found, otherwise None.
416
417        Raises:
418            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
419        """
420        if current_depth > max_depth:
421            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
422            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
423            raise ResponseKeysMaxRecurtionReached(
424                internal_message=message, message=message, failure_type=FailureType.config_error
425            )
426
427        if isinstance(response_data, dict):
428            # get from the root level
429            if key_name in response_data:
430                return response_data[key_name]
431
432            # get from the nested object
433            for _, value in response_data.items():
434                result = self._find_and_get_value_from_response(
435                    value, key_name, max_depth, current_depth + 1
436                )
437                if result is not None:
438                    return result
439
440        # get from the nested array object
441        elif isinstance(response_data, list):
442            for item in response_data:
443                result = self._find_and_get_value_from_response(
444                    item, key_name, max_depth, current_depth + 1
445                )
446                if result is not None:
447                    return result
448
449        return None
450
451    @property
452    def _message_repository(self) -> Optional[MessageRepository]:
453        """
454        The implementation can define a message_repository if it wants debugging logs for HTTP requests
455        """
456        return _NOOP_MESSAGE_REPOSITORY
457
458    def _log_response(self, response: requests.Response) -> None:
459        """
460        Logs the HTTP response using the message repository if it is available.
461
462        Args:
463            response (requests.Response): The HTTP response to log.
464        """
465        if self._message_repository:
466            self._message_repository.log_message(
467                Level.DEBUG,
468                lambda: format_http_message(
469                    response,
470                    "Refresh token",
471                    "Obtains access token",
472                    self._NO_STREAM_NAME,
473                    is_auxiliary=True,
474                    type="AUTH",
475                ),
476            )
477
478    # ----------------
479    # ABSTR METHODS
480    # ----------------
481
482    @abstractmethod
483    def get_token_refresh_endpoint(self) -> Optional[str]:
484        """Returns the endpoint to refresh the access token"""
485
486    @abstractmethod
487    def get_client_id_name(self) -> str:
488        """The client id name to authenticate"""
489
490    @abstractmethod
491    def get_client_id(self) -> str:
492        """The client id to authenticate"""
493
494    @abstractmethod
495    def get_client_secret_name(self) -> str:
496        """The client secret name to authenticate"""
497
498    @abstractmethod
499    def get_client_secret(self) -> str:
500        """The client secret to authenticate"""
501
502    @abstractmethod
503    def get_refresh_token_name(self) -> str:
504        """The refresh token name to authenticate"""
505
506    @abstractmethod
507    def get_refresh_token(self) -> Optional[str]:
508        """The token used to refresh the access token when it expires"""
509
510    @abstractmethod
511    def get_scopes(self) -> List[str]:
512        """List of requested scopes"""
513
514    @abstractmethod
515    def get_token_expiry_date(self) -> AirbyteDateTime:
516        """Expiration date of the access token"""
517
518    @abstractmethod
519    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
520        """Setter for access token expiration date"""
521
522    @abstractmethod
523    def get_access_token_name(self) -> str:
524        """Field to extract access token from in the response"""
525
526    @abstractmethod
527    def get_expires_in_name(self) -> str:
528        """Returns the expires_in field name"""
529
530    @abstractmethod
531    def get_refresh_request_body(self) -> Mapping[str, Any]:
532        """Returns the request body to set on the refresh request"""
533
534    @abstractmethod
535    def get_refresh_request_headers(self) -> Mapping[str, Any]:
536        """Returns the request headers to set on the refresh request"""
537
538    @abstractmethod
539    def get_grant_type(self) -> str:
540        """Returns grant_type specified for requesting access_token"""
541
542    @abstractmethod
543    def get_grant_type_name(self) -> str:
544        """Returns grant_type specified name for requesting access_token"""
545
546    @property
547    @abstractmethod
548    def access_token(self) -> str:
549        """Returns the access token"""
550
551    @access_token.setter
552    @abstractmethod
553    def access_token(self, value: str) -> str:
554        """Setter for the access token"""
logger = <Logger airbyte (INFO)>
class ResponseKeysMaxRecurtionReached(airbyte_cdk.utils.traced_exception.AirbyteTracedException):
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Signals that the recursive key lookup over a token refresh response went
    deeper than the allowed maximum depth without locating the requested key.

    Raised by the find-and-get helper used while handling `_make_handled_request`.
    """

Raised when the maximum recursion depth is reached while searching the token refresh response for the target key during `_make_handled_request`.

class AbstractOauth2Authenticator(requests.auth.AuthBase):
 37class AbstractOauth2Authenticator(AuthBase):
 38    """
 39    Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
 40    is designed to generically perform the refresh flow without regard to how config fields are get/set by
 41    delegating that behavior to the classes implementing the interface.
 42    """
 43
 44    _NO_STREAM_NAME = None
 45
 46    # Class-level lock to prevent concurrent token refresh across multiple authenticator instances.
 47    # This is necessary because multiple streams may share the same OAuth credentials (refresh token)
 48    # through the connector config. Without this lock, concurrent refresh attempts can cause race
 49    # conditions where one stream successfully refreshes the token while others fail because the
 50    # refresh token has been invalidated (especially for single-use refresh tokens).
 51    _token_refresh_lock: threading.Lock = threading.Lock()
 52
 53    def __init__(
 54        self,
 55        refresh_token_error_status_codes: Tuple[int, ...] = (),
 56        refresh_token_error_key: str = "",
 57        refresh_token_error_values: Tuple[str, ...] = (),
 58    ) -> None:
 59        """
 60        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
 61        then http errors with such params will be wrapped in AirbyteTracedException.
 62        """
 63        self._refresh_token_error_status_codes = refresh_token_error_status_codes
 64        self._refresh_token_error_key = refresh_token_error_key
 65        self._refresh_token_error_values = refresh_token_error_values
 66
 67    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
 68        """Attach the HTTP headers required to authenticate on the HTTP request"""
 69        request.headers.update(self.get_auth_header())
 70        return request
 71
 72    @property
 73    def _is_access_token_flow(self) -> bool:
 74        return self.get_token_refresh_endpoint() is None and self.access_token is not None
 75
 76    @property
 77    def token_expiry_is_time_of_expiration(self) -> bool:
 78        """
 79        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
 80        """
 81
 82        return False
 83
 84    @property
 85    def token_expiry_date_format(self) -> Optional[str]:
 86        """
 87        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
 88        """
 89
 90        return None
 91
 92    def get_auth_header(self) -> Mapping[str, Any]:
 93        """HTTP header to set on the requests"""
 94        token = self.access_token if self._is_access_token_flow else self.get_access_token()
 95        return {"Authorization": f"Bearer {token}"}
 96
 97    def get_access_token(self) -> str:
 98        """
 99        Returns the access token.
100
101        This method uses double-checked locking to ensure thread-safe token refresh.
102        When multiple threads (streams) detect an expired token simultaneously, only one
103        will perform the refresh while others wait. After acquiring the lock, the token
104        expiry is re-checked to avoid redundant refresh attempts.
105        """
106        if self.token_has_expired():
107            with self._token_refresh_lock:
108                # Double-check after acquiring lock - another thread may have already refreshed
109                if self.token_has_expired():
110                    self.refresh_and_set_access_token()
111
112        return self.access_token
113
114    def refresh_and_set_access_token(self) -> None:
115        """Force refresh the access token and update internal state.
116
117        This method refreshes the access token regardless of whether it has expired,
118        and updates the internal token and expiry date. Subclasses may override this
119        to handle additional state updates (e.g., persisting new refresh tokens).
120        """
121        token, expires_in = self.refresh_access_token()
122        self.access_token = token
123        self.set_token_expiry_date(expires_in)
124
125    def token_has_expired(self) -> bool:
126        """Returns True if the token is expired"""
127        return ab_datetime_now() > self.get_token_expiry_date()
128
129    def build_refresh_request_body(self) -> Mapping[str, Any]:
130        """
131        Returns the request body to set on the refresh request.
132
133        Override to define additional parameters.
134
135        Client credentials (client_id and client_secret) are excluded from the body when
136        refresh_request_headers contains an Authorization header (e.g., Basic auth).
137        This is required by OAuth providers like Gong that expect credentials ONLY in the
138        Authorization header and reject requests that include them in both places.
139        """
140        # Check if credentials are being sent via Authorization header
141        headers = self.get_refresh_request_headers()
142        credentials_in_header = headers and "Authorization" in headers
143
144        # Only include client credentials in body if not already in header
145        include_client_credentials = not credentials_in_header
146
147        payload: MutableMapping[str, Any] = {
148            self.get_grant_type_name(): self.get_grant_type(),
149        }
150
151        # Only include client credentials in body if configured to do so and not in header
152        if include_client_credentials:
153            payload[self.get_client_id_name()] = self.get_client_id()
154            payload[self.get_client_secret_name()] = self.get_client_secret()
155
156        payload[self.get_refresh_token_name()] = self.get_refresh_token()
157
158        if self.get_scopes():
159            payload["scopes"] = self.get_scopes()
160
161        if self.get_refresh_request_body():
162            for key, val in self.get_refresh_request_body().items():
163                # We defer to existing oauth constructs over custom configured fields
164                if key not in payload:
165                    payload[key] = val
166
167        return payload
168
169    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
170        """
171        Returns the request headers to set on the refresh request
172
173        """
174        headers = self.get_refresh_request_headers()
175        return headers if headers else None
176
177    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
178        """
179        Returns the refresh token and its expiration datetime
180
181        :return: a tuple of (access_token, token_lifespan)
182        """
183        response_json = self._make_handled_request()
184        self._ensure_access_token_in_response(response_json)
185
186        return (
187            self._extract_access_token(response_json),
188            self._extract_token_expiry_date(response_json),
189        )
190
191    # ----------------
192    # PRIVATE METHODS
193    # ----------------
194
195    def _default_token_expiry_date(self) -> AirbyteDateTime:
196        """
197        Returns the default token expiry date
198        """
199        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
200        default_token_expiry_duration_hours = 1  # 1 hour
201        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
202
203    def _wrap_refresh_token_exception(
204        self, exception: requests.exceptions.RequestException
205    ) -> bool:
206        """
207        Wraps and handles exceptions that occur during the refresh token process.
208
209        This method checks if the provided exception is related to a refresh token error
210        by examining the response status code and specific error content.
211
212        Args:
213            exception (requests.exceptions.RequestException): The exception raised during the request.
214
215        Returns:
216            bool: True if the exception is related to a refresh token error, False otherwise.
217        """
218        try:
219            if exception.response is not None:
220                exception_content = exception.response.json()
221            else:
222                return False
223        except JSONDecodeError:
224            return False
225        return (
226            exception.response.status_code in self._refresh_token_error_status_codes
227            and exception_content.get(self._refresh_token_error_key)
228            in self._refresh_token_error_values
229        )
230
    # Retry with exponential backoff whenever the method raises DefaultBackoffException
    # (raised below for 429 and 5xx responses); `max_time=300` bounds the total retry
    # time (seconds, per the backoff library).
    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint, using
        `build_refresh_request_body()` as form data and `build_refresh_request_headers()`
        as headers. The response is logged through `_log_response` on both success and
        failure; on success the access token, if it can be extracted, is registered as a
        secret before the response is logged.

        Returns:
            Any: The decoded JSON body of the token refresh response.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                                     or any 5xx server error. The decorator retries on it.
            AirbyteTracedException: If `_wrap_refresh_token_exception` recognises the failure
                                    as an invalid or expired refresh token (config error),
                                    prompting re-authentication.
            requests.exceptions.RequestException: Any other request failure, re-raised unchanged.
            Exception: Any non-request error, wrapped with an
                       "Error while refreshing access token" message.
        """
        try:
            # NOTE(review): no `timeout` is passed, so the call relies on requests' default.
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )

            if not response.ok:
                # log the response even if the request failed for troubleshooting purposes
                self._log_response(response)
                # raises an HTTPError that is handled by the RequestException branch below
                response.raise_for_status()

            response_json = response.json()

            try:
                # extract the access token and add to secrets to avoid logging the raw value
                access_key = self._extract_access_token(response_json)
                if access_key:
                    add_to_secrets(access_key)
            except ResponseKeysMaxRecurtionReached as e:
                # The lookup hit its nesting-depth limit before finding the token; skip the
                # secret registration here. `refresh_access_token` validates the token's
                # presence afterwards.
                pass

            # Logged only after the token has been registered as a secret.
            self._log_response(response)

            return response_json
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                # Rate limiting and server errors are transient: signal the backoff decorator.
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(
                        request=e.response.request,
                        response=e.response,
                        failure_type=FailureType.transient_error,
                    )
            # A response matching the configured refresh-token error becomes a config error.
            if self._wrap_refresh_token_exception(e):
                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                raise AirbyteTracedException(
                    internal_message=message, message=message, failure_type=FailureType.config_error
                )
            # Anything else is propagated as the original request exception.
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e
300
301    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
302        """
303        Ensures that the access token is present in the response data.
304
305        This method attempts to extract the access token from the provided response data.
306        If the access token is not found, it raises an exception indicating that the token
307        refresh API response was missing the access token.
308
309        Args:
310            response_data (Mapping[str, Any]): The response data from which to extract the access token.
311
312        Raises:
313            Exception: If the access token is not found in the response data.
314            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
315        """
316        try:
317            access_key = self._extract_access_token(response_data)
318            if not access_key:
319                raise Exception(
320                    f"Token refresh API response was missing access token {self.get_access_token_name()}"
321                )
322        except ResponseKeysMaxRecurtionReached as e:
323            raise e
324
325    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
326        """
327        Parse a string or integer token expiration date into a datetime object
328
329        :return: expiration datetime
330        """
331        if self.token_expiry_is_time_of_expiration:
332            if not self.token_expiry_date_format:
333                raise ValueError(
334                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
335                )
336            try:
337                return ab_datetime_parse(str(value))
338            except ValueError as e:
339                raise ValueError(f"Invalid token expiry date format: {e}")
340        else:
341            try:
342                # Only accept numeric values (as int/float/string) when no format specified
343                seconds = int(float(str(value)))
344                return ab_datetime_now() + timedelta(seconds=seconds)
345            except (ValueError, TypeError):
346                raise ValueError(
347                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
348                )
349
350    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
351        """
352        Extracts the access token from the given response data.
353
354        Args:
355            response_data (Mapping[str, Any]): The response data from which to extract the access token.
356
357        Returns:
358            str: The extracted access token.
359        """
360        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())
361
362    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
363        """
364        Extracts the refresh token from the given response data.
365
366        Args:
367            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.
368
369        Returns:
370            str: The extracted refresh token.
371        """
372        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
373
374    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
375        """
376        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
377
378        If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
379
380        Args:
381            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
382
383        Returns:
384            The extracted token_expiry_date or None if not found.
385        """
386        expires_in = self._find_and_get_value_from_response(
387            response_data, self.get_expires_in_name()
388        )
389        if expires_in is not None:
390            return self._parse_token_expiration_date(expires_in)
391
392        # expires_in is None
393        existing_expiry_date = self.get_token_expiry_date()
394        if existing_expiry_date and not self.token_has_expired():
395            return existing_expiry_date
396
397        return self._default_token_expiry_date()
398
399    def _find_and_get_value_from_response(
400        self,
401        response_data: Mapping[str, Any],
402        key_name: str,
403        max_depth: int = 5,
404        current_depth: int = 0,
405    ) -> Any:
406        """
407        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.
408
409        Args:
410            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
411            key_name (str): The key to search for in the response data.
412            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
413            current_depth (int, optional): The current depth of the recursion. Defaults to 0.
414
415        Returns:
416            Any: The value associated with the specified key if found, otherwise None.
417
418        Raises:
419            AirbyteTracedException: If the maximum recursion depth is reached without finding the key.
420        """
421        if current_depth > max_depth:
422            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
423            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
424            raise ResponseKeysMaxRecurtionReached(
425                internal_message=message, message=message, failure_type=FailureType.config_error
426            )
427
428        if isinstance(response_data, dict):
429            # get from the root level
430            if key_name in response_data:
431                return response_data[key_name]
432
433            # get from the nested object
434            for _, value in response_data.items():
435                result = self._find_and_get_value_from_response(
436                    value, key_name, max_depth, current_depth + 1
437                )
438                if result is not None:
439                    return result
440
441        # get from the nested array object
442        elif isinstance(response_data, list):
443            for item in response_data:
444                result = self._find_and_get_value_from_response(
445                    item, key_name, max_depth, current_depth + 1
446                )
447                if result is not None:
448                    return result
449
450        return None
451
452    @property
453    def _message_repository(self) -> Optional[MessageRepository]:
454        """
455        The implementation can define a message_repository if it wants debugging logs for HTTP requests
456        """
457        return _NOOP_MESSAGE_REPOSITORY
458
459    def _log_response(self, response: requests.Response) -> None:
460        """
461        Logs the HTTP response using the message repository if it is available.
462
463        Args:
464            response (requests.Response): The HTTP response to log.
465        """
466        if self._message_repository:
467            self._message_repository.log_message(
468                Level.DEBUG,
469                lambda: format_http_message(
470                    response,
471                    "Refresh token",
472                    "Obtains access token",
473                    self._NO_STREAM_NAME,
474                    is_auxiliary=True,
475                    type="AUTH",
476                ),
477            )
478
479    # ----------------
480    # ABSTRACT METHODS
481    # ----------------
482
    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Return the URL that `_make_handled_request` sends the POST refresh request to."""
486
    @abstractmethod
    def get_client_id_name(self) -> str:
        """Return the refresh request body key under which the client id is sent."""
490
    @abstractmethod
    def get_client_id(self) -> str:
        """Return the client id value placed in the refresh request body when client credentials are included."""
494
    @abstractmethod
    def get_client_secret_name(self) -> str:
        """Return the refresh request body key under which the client secret is sent."""
498
    @abstractmethod
    def get_client_secret(self) -> str:
        """Return the client secret value placed in the refresh request body when client credentials are included."""
502
    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """Return the key used for the refresh token, both in the refresh request body and when searching the response."""
506
    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """Return the token sent in the refresh request body to obtain a new access token when it expires."""
510
    @abstractmethod
    def get_scopes(self) -> List[str]:
        """Return the list of requested scopes; when non-empty it is sent under the "scopes" key of the refresh request body."""
514
    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Return the expiration date of the current access token."""
518
    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Store `value` as the new expiration date of the access token."""
522
    @abstractmethod
    def get_access_token_name(self) -> str:
        """Return the name of the response field that `_extract_access_token` reads the access token from."""
526
    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Return the name of the response field that `_extract_token_expiry_date` reads the token expiry from."""
530
    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Return extra fields for the refresh request body; keys already set by the OAuth payload are not overridden."""
534
    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Return the request headers to set on the refresh request (an empty value results in no headers being passed)."""
538
    @abstractmethod
    def get_grant_type(self) -> str:
        """Return the grant type value sent in the refresh request body when requesting an access token."""
542
    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Return the refresh request body key under which the grant type is sent."""
546
    @property
    @abstractmethod
    def access_token(self) -> str:
        """Return the currently stored access token (the value `get_access_token` hands out)."""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Store a newly obtained access token (assigned by `refresh_and_set_access_token`)."""

Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator is designed to perform the refresh flow generically, without regard to how config fields are read or written, by delegating that behavior to the classes implementing the interface.

AbstractOauth2Authenticator( refresh_token_error_status_codes: Tuple[int, ...] = (), refresh_token_error_key: str = '', refresh_token_error_values: Tuple[str, ...] = ())
53    def __init__(
54        self,
55        refresh_token_error_status_codes: Tuple[int, ...] = (),
56        refresh_token_error_key: str = "",
57        refresh_token_error_values: Tuple[str, ...] = (),
58    ) -> None:
59        """
60        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
61        then http errors with such params will be wrapped in AirbyteTracedException.
62        """
63        self._refresh_token_error_status_codes = refresh_token_error_status_codes
64        self._refresh_token_error_key = refresh_token_error_key
65        self._refresh_token_error_values = refresh_token_error_values

If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, then http errors with such params will be wrapped in AirbyteTracedException.

token_expiry_is_time_of_expiration: bool
76    @property
77    def token_expiry_is_time_of_expiration(self) -> bool:
78        """
79        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
80        """
81
82        return False

Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.

token_expiry_date_format: Optional[str]
84    @property
85    def token_expiry_date_format(self) -> Optional[str]:
86        """
87        Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires
88        """
89
90        return None

Format of the datetime; specify it if expires_in is returned as the expiration datetime instead of the number of seconds until the token expires

def get_auth_header(self) -> Mapping[str, Any]:
92    def get_auth_header(self) -> Mapping[str, Any]:
93        """HTTP header to set on the requests"""
94        token = self.access_token if self._is_access_token_flow else self.get_access_token()
95        return {"Authorization": f"Bearer {token}"}

HTTP header to set on the requests

def get_access_token(self) -> str:
 97    def get_access_token(self) -> str:
 98        """
 99        Returns the access token.
100
101        This method uses double-checked locking to ensure thread-safe token refresh.
102        When multiple threads (streams) detect an expired token simultaneously, only one
103        will perform the refresh while others wait. After acquiring the lock, the token
104        expiry is re-checked to avoid redundant refresh attempts.
105        """
106        if self.token_has_expired():
107            with self._token_refresh_lock:
108                # Double-check after acquiring lock - another thread may have already refreshed
109                if self.token_has_expired():
110                    self.refresh_and_set_access_token()
111
112        return self.access_token

Returns the access token.

This method uses double-checked locking to ensure thread-safe token refresh. When multiple threads (streams) detect an expired token simultaneously, only one will perform the refresh while others wait. After acquiring the lock, the token expiry is re-checked to avoid redundant refresh attempts.

def refresh_and_set_access_token(self) -> None:
114    def refresh_and_set_access_token(self) -> None:
115        """Force refresh the access token and update internal state.
116
117        This method refreshes the access token regardless of whether it has expired,
118        and updates the internal token and expiry date. Subclasses may override this
119        to handle additional state updates (e.g., persisting new refresh tokens).
120        """
121        token, expires_in = self.refresh_access_token()
122        self.access_token = token
123        self.set_token_expiry_date(expires_in)

Force refresh the access token and update internal state.

This method refreshes the access token regardless of whether it has expired, and updates the internal token and expiry date. Subclasses may override this to handle additional state updates (e.g., persisting new refresh tokens).

def token_has_expired(self) -> bool:
125    def token_has_expired(self) -> bool:
126        """Returns True if the token is expired"""
127        return ab_datetime_now() > self.get_token_expiry_date()

Returns True if the token is expired

def build_refresh_request_body(self) -> Mapping[str, Any]:
129    def build_refresh_request_body(self) -> Mapping[str, Any]:
130        """
131        Returns the request body to set on the refresh request.
132
133        Override to define additional parameters.
134
135        Client credentials (client_id and client_secret) are excluded from the body when
136        refresh_request_headers contains an Authorization header (e.g., Basic auth).
137        This is required by OAuth providers like Gong that expect credentials ONLY in the
138        Authorization header and reject requests that include them in both places.
139        """
140        # Check if credentials are being sent via Authorization header
141        headers = self.get_refresh_request_headers()
142        credentials_in_header = headers and "Authorization" in headers
143
144        # Only include client credentials in body if not already in header
145        include_client_credentials = not credentials_in_header
146
147        payload: MutableMapping[str, Any] = {
148            self.get_grant_type_name(): self.get_grant_type(),
149        }
150
151        # Only include client credentials in body if configured to do so and not in header
152        if include_client_credentials:
153            payload[self.get_client_id_name()] = self.get_client_id()
154            payload[self.get_client_secret_name()] = self.get_client_secret()
155
156        payload[self.get_refresh_token_name()] = self.get_refresh_token()
157
158        if self.get_scopes():
159            payload["scopes"] = self.get_scopes()
160
161        if self.get_refresh_request_body():
162            for key, val in self.get_refresh_request_body().items():
163                # We defer to existing oauth constructs over custom configured fields
164                if key not in payload:
165                    payload[key] = val
166
167        return payload

Returns the request body to set on the refresh request.

Override to define additional parameters.

Client credentials (client_id and client_secret) are excluded from the body when refresh_request_headers contains an Authorization header (e.g., Basic auth). This is required by OAuth providers like Gong that expect credentials ONLY in the Authorization header and reject requests that include them in both places.

def build_refresh_request_headers(self) -> Optional[Mapping[str, Any]]:
169    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
170        """
171        Returns the request headers to set on the refresh request
172
173        """
174        headers = self.get_refresh_request_headers()
175        return headers if headers else None

Returns the request headers to set on the refresh request

def refresh_access_token(self) -> Tuple[str, airbyte_cdk.utils.datetime_helpers.AirbyteDateTime]:
177    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
178        """
179        Returns the refresh token and its expiration datetime
180
181        :return: a tuple of (access_token, token_lifespan)
182        """
183        response_json = self._make_handled_request()
184        self._ensure_access_token_in_response(response_json)
185
186        return (
187            self._extract_access_token(response_json),
188            self._extract_token_expiry_date(response_json),
189        )

Returns the access token and its expiration datetime

Returns

a tuple of (access_token, token_expiry_date)

@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
483    @abstractmethod
484    def get_token_refresh_endpoint(self) -> Optional[str]:
485        """Returns the endpoint to refresh the access token"""

Returns the endpoint to refresh the access token

@abstractmethod
def get_client_id_name(self) -> str:
487    @abstractmethod
488    def get_client_id_name(self) -> str:
489        """The client id name to authenticate"""

The client id name to authenticate

@abstractmethod
def get_client_id(self) -> str:
491    @abstractmethod
492    def get_client_id(self) -> str:
493        """The client id to authenticate"""

The client id to authenticate

@abstractmethod
def get_client_secret_name(self) -> str:
495    @abstractmethod
496    def get_client_secret_name(self) -> str:
497        """The client secret name to authenticate"""

The client secret name to authenticate

@abstractmethod
def get_client_secret(self) -> str:
499    @abstractmethod
500    def get_client_secret(self) -> str:
501        """The client secret to authenticate"""

The client secret to authenticate

@abstractmethod
def get_refresh_token_name(self) -> str:
503    @abstractmethod
504    def get_refresh_token_name(self) -> str:
505        """The refresh token name to authenticate"""

The refresh token name to authenticate

@abstractmethod
def get_refresh_token(self) -> Optional[str]:
507    @abstractmethod
508    def get_refresh_token(self) -> Optional[str]:
509        """The token used to refresh the access token when it expires"""

The token used to refresh the access token when it expires

@abstractmethod
def get_scopes(self) -> List[str]:
511    @abstractmethod
512    def get_scopes(self) -> List[str]:
513        """List of requested scopes"""

List of requested scopes

@abstractmethod
def get_token_expiry_date(self) -> airbyte_cdk.utils.datetime_helpers.AirbyteDateTime:
515    @abstractmethod
516    def get_token_expiry_date(self) -> AirbyteDateTime:
517        """Expiration date of the access token"""

Expiration date of the access token

@abstractmethod
def set_token_expiry_date(self, value: airbyte_cdk.utils.datetime_helpers.AirbyteDateTime) -> None:
519    @abstractmethod
520    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
521        """Setter for access token expiration date"""

Setter for access token expiration date

@abstractmethod
def get_access_token_name(self) -> str:
523    @abstractmethod
524    def get_access_token_name(self) -> str:
525        """Field to extract access token from in the response"""

Field to extract access token from in the response

@abstractmethod
def get_expires_in_name(self) -> str:
527    @abstractmethod
528    def get_expires_in_name(self) -> str:
529        """Returns the expires_in field name"""

Returns the expires_in field name

@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
531    @abstractmethod
532    def get_refresh_request_body(self) -> Mapping[str, Any]:
533        """Returns the request body to set on the refresh request"""

Returns the request body to set on the refresh request

@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
535    @abstractmethod
536    def get_refresh_request_headers(self) -> Mapping[str, Any]:
537        """Returns the request headers to set on the refresh request"""

Returns the request headers to set on the refresh request

@abstractmethod
def get_grant_type(self) -> str:
539    @abstractmethod
540    def get_grant_type(self) -> str:
541        """Returns grant_type specified for requesting access_token"""

Returns grant_type specified for requesting access_token

@abstractmethod
def get_grant_type_name(self) -> str:
543    @abstractmethod
544    def get_grant_type_name(self) -> str:
545        """Returns grant_type specified name for requesting access_token"""

Returns grant_type specified name for requesting access_token

access_token: str
547    @property
548    @abstractmethod
549    def access_token(self) -> str:
550        """Returns the access token"""

Returns the access token