airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
import threading
from abc import abstractmethod
from datetime import timedelta
from json import JSONDecodeError
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

import backoff
import requests
from requests.auth import AuthBase

from airbyte_cdk.models import FailureType, Level
from airbyte_cdk.sources.http_logger import format_http_message
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse

from ..exceptions import DefaultBackoffException

logger = logging.getLogger("airbyte")
# Shared no-op repository returned by the default `_message_repository` property.
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()


class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is reached while trying to find and get
    a target key in a token refresh response (see `_find_and_get_value_from_response`).

    `_make_handled_request` swallows it when looking up the access token for secret
    masking; `_ensure_access_token_in_response` lets it propagate.
    """


class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator
    is designed to generically perform the refresh flow without regard to how config fields are get/set by
    delegating that behavior to the classes implementing the interface.
    """

    _NO_STREAM_NAME = None

    # Class-level lock to prevent concurrent token refresh across multiple authenticator instances.
    # This is necessary because multiple streams may share the same OAuth credentials (refresh token)
    # through the connector config. Without this lock, concurrent refresh attempts can cause race
    # conditions where one stream successfully refreshes the token while others fail because the
    # refresh token has been invalidated (especially for single-use refresh tokens).
    _token_refresh_lock: threading.Lock = threading.Lock()

    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then http errors with such params will be wrapped in AirbyteTracedException.
        """
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the HTTP headers required to authenticate on the HTTP request"""
        request.headers.update(self.get_auth_header())
        return request

    @property
    def _is_access_token_flow(self) -> bool:
        # True when no refresh endpoint is configured but an access token is already available;
        # in that case the stored token is used as-is and never refreshed.
        return self.get_token_refresh_endpoint() is None and self.access_token is not None

    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
        """

        return False

    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the datetime; set it if expires_in is returned as the expiration datetime instead of seconds until it expires
        """

        return None

    def get_auth_header(self) -> Mapping[str, Any]:
        """HTTP header to set on the requests"""
        token = self.access_token if self._is_access_token_flow else self.get_access_token()
        return {"Authorization": f"Bearer {token}"}

    def get_access_token(self) -> str:
        """
        Returns the access token.

        This method uses double-checked locking to ensure thread-safe token refresh.
        When multiple threads (streams) detect an expired token simultaneously, only one
        will perform the refresh while others wait. After acquiring the lock, the token
        expiry is re-checked to avoid redundant refresh attempts.
        """
        if self.token_has_expired():
            with self._token_refresh_lock:
                # Double-check after acquiring lock - another thread may have already refreshed
                if self.token_has_expired():
                    self.refresh_and_set_access_token()

        return self.access_token

    def refresh_and_set_access_token(self) -> None:
        """Force refresh the access token and update internal state.

        This method refreshes the access token regardless of whether it has expired,
        and updates the internal token and expiry date. Subclasses may override this
        to handle additional state updates (e.g., persisting new refresh tokens).
        """
        token, expires_in = self.refresh_access_token()
        self.access_token = token
        self.set_token_expiry_date(expires_in)

    def token_has_expired(self) -> bool:
        """Returns True if the token is expired"""
        return ab_datetime_now() > self.get_token_expiry_date()

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request.

        Override to define additional parameters.

        Client credentials (client_id and client_secret) are excluded from the body when
        refresh_request_headers contains an Authorization header (e.g., Basic auth).
        This is required by OAuth providers like Gong that expect credentials ONLY in the
        Authorization header and reject requests that include them in both places.
        """
        # Credentials count as "sent in the header" when the refresh headers carry an Authorization entry
        headers = self.get_refresh_request_headers()
        credentials_in_header = bool(headers) and "Authorization" in headers

        payload: MutableMapping[str, Any] = {
            self.get_grant_type_name(): self.get_grant_type(),
        }

        # Only include client credentials in the body when they are not already sent in the header
        if not credentials_in_header:
            payload[self.get_client_id_name()] = self.get_client_id()
            payload[self.get_client_secret_name()] = self.get_client_secret()

        payload[self.get_refresh_token_name()] = self.get_refresh_token()

        scopes = self.get_scopes()
        if scopes:
            payload["scopes"] = scopes

        custom_body = self.get_refresh_request_body()
        if custom_body:
            for key, val in custom_body.items():
                # We defer to existing oauth constructs over custom configured fields
                if key not in payload:
                    payload[key] = val

        return payload

    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
        """
        Returns the request headers to set on the refresh request, or None when there are none.
        """
        headers = self.get_refresh_request_headers()
        return headers if headers else None

    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
        """
        Requests a new access token from the token refresh endpoint.

        :return: a tuple of (access_token, token_expiry_date); the expiry date is parsed from
            the response, or falls back to the existing/default expiry when absent.
        :raises AirbyteTracedException: (transient_error) if the request still fails with a
            network error after `_make_handled_request` has exhausted its retries.
        """
        try:
            response_json = self._make_handled_request()
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.ConnectTimeout,
            requests.exceptions.ReadTimeout,
        ) as e:
            raise AirbyteTracedException(
                message="OAuth access token refresh request failed due to a network error.",
                internal_message=f"Network error during OAuth token refresh after retries were exhausted: {e}",
                failure_type=FailureType.transient_error,
            ) from e
        self._ensure_access_token_in_response(response_json)

        return (
            self._extract_access_token(response_json),
            self._extract_token_expiry_date(response_json),
        )

    # ----------------
    # PRIVATE METHODS
    # ----------------

    def _default_token_expiry_date(self) -> AirbyteDateTime:
        """
        Returns the default token expiry date (one hour from now).
        """
        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
        default_token_expiry_duration_hours = 1  # 1 hour
        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)

    def _wrap_refresh_token_exception(
        self, exception: requests.exceptions.RequestException
    ) -> bool:
        """
        Checks whether a failed refresh request represents an invalid/expired refresh token.

        The exception matches when its response status code is one of the configured
        `refresh_token_error_status_codes` and the JSON body's `refresh_token_error_key`
        value is one of the configured `refresh_token_error_values`.

        Args:
            exception (requests.exceptions.RequestException): The exception raised during the request.

        Returns:
            bool: True if the exception is related to a refresh token error, False otherwise
            (including when there is no response, the body is not JSON, or the body is not a JSON object).
        """
        response = exception.response
        if response is None:
            return False
        try:
            exception_content = response.json()
        except (JSONDecodeError, ValueError):
            # `requests` raises its own JSONDecodeError, which is a ValueError but not a
            # json.JSONDecodeError when simplejson is installed; catch the common base.
            return False
        if not isinstance(exception_content, dict):
            # A JSON array or scalar cannot carry the configured error key.
            return False
        return (
            response.status_code in self._refresh_token_error_status_codes
            and exception_content.get(self._refresh_token_error_key)
            in self._refresh_token_error_values
        )

    @backoff.on_exception(
        backoff.expo,
        (
            DefaultBackoffException,
            requests.exceptions.ConnectionError,
            requests.exceptions.ConnectTimeout,
            requests.exceptions.ReadTimeout,
        ),
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the necessary
        headers and body to obtain a new access token. The access token found in the
        response (if any) is registered as a secret before the response is logged.
        Retryable errors are retried with exponential backoff for up to 300 seconds.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication (config_error), or if a non-`requests` exception occurs
                while making the request (system_error).
            requests.exceptions.RequestException: Any other request error is re-raised as-is.
        """
        try:
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )

            if not response.ok:
                # log the response even if the request failed for troubleshooting purposes
                self._log_response(response)
                response.raise_for_status()

            response_json = response.json()

            try:
                # extract the access token and add to secrets to avoid logging the raw value
                access_key = self._extract_access_token(response_json)
                if access_key:
                    add_to_secrets(access_key)
            except ResponseKeysMaxRecurtionReached:
                # could not find the access token in the response, so do nothing
                pass

            self._log_response(response)

            return response_json
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(
                        request=e.response.request,
                        response=e.response,
                        failure_type=FailureType.transient_error,
                    )
            if self._wrap_refresh_token_exception(e):
                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                raise AirbyteTracedException(
                    internal_message=message, message=message, failure_type=FailureType.config_error
                )
            raise
        except Exception as e:
            raise AirbyteTracedException(
                message="OAuth access token refresh request failed.",
                internal_message=f"Unexpected error during OAuth token refresh: {e}",
                failure_type=FailureType.system_error,
            ) from e

    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
        """
        Ensures that the access token is present in the response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Raises:
            Exception: If the access token is missing or falsy in the response data.
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
        """
        access_key = self._extract_access_token(response_data)
        if not access_key:
            raise Exception(
                f"Token refresh API response was missing access token {self.get_access_token_name()}"
            )

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """
        Parse a string or integer token expiration date into a datetime object.

        When `token_expiry_is_time_of_expiration` is True, `value` is parsed as a datetime
        (a `token_expiry_date_format` must be set, although it is not passed to the parser);
        otherwise `value` is interpreted as a number of seconds from now.

        :return: expiration datetime
        :raises ValueError: if the format is missing or `value` cannot be parsed.
        """
        if self.token_expiry_is_time_of_expiration:
            if not self.token_expiry_date_format:
                raise ValueError(
                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
                )
            try:
                return ab_datetime_parse(str(value))
            except ValueError as e:
                raise ValueError(f"Invalid token expiry date format: {e}") from e
        else:
            try:
                # Only accept numeric values (as int/float/string) when no format specified
                seconds = int(float(str(value)))
                return ab_datetime_now() + timedelta(seconds=seconds)
            except (ValueError, TypeError) as e:
                raise ValueError(
                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
                ) from e

    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the access token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Returns:
            The extracted access token, or None if it was not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())

    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the refresh token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.

        Returns:
            The extracted refresh token, or None if it was not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
        """
        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

        If the token_expiry_date is not found, it will return the existing token expiry date if it is
        set and not yet expired, or otherwise a default token expiry date.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

        Returns:
            The parsed, existing, or default token expiry date (never None).
        """
        expires_in = self._find_and_get_value_from_response(
            response_data, self.get_expires_in_name()
        )
        if expires_in is not None:
            return self._parse_token_expiration_date(expires_in)

        # expires_in is None
        existing_expiry_date = self.get_token_expiry_date()
        if existing_expiry_date and not self.token_has_expired():
            return existing_expiry_date

        return self._default_token_expiry_date()

    def _find_and_get_value_from_response(
        self,
        response_data: Any,
        key_name: str,
        max_depth: int = 5,
        current_depth: int = 0,
    ) -> Any:
        """
        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.

        Dicts are checked at their own level first, then each value is searched depth-first;
        list items are searched in order. Non-container values yield None.

        Args:
            response_data (Any): The response data to search through (dict, list, or scalar).
            key_name (str): The key to search for in the response data.
            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
            current_depth (int, optional): The current depth of the recursion. Defaults to 0.

        Returns:
            Any: The value associated with the specified key if found, otherwise None.

        Raises:
            ResponseKeysMaxRecurtionReached: If the recursion goes deeper than `max_depth`. The depth
                check runs before the type check, so this also triggers for scalar values nested
                deeper than `max_depth`.
        """
        if current_depth > max_depth:
            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
            raise ResponseKeysMaxRecurtionReached(
                internal_message=message, message=message, failure_type=FailureType.config_error
            )

        if isinstance(response_data, dict):
            # get from the root level
            if key_name in response_data:
                return response_data[key_name]

            # get from the nested object
            for value in response_data.values():
                result = self._find_and_get_value_from_response(
                    value, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        # get from the nested array object
        elif isinstance(response_data, list):
            for item in response_data:
                result = self._find_and_get_value_from_response(
                    item, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        return None

    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        The implementation can define a message_repository if it wants debugging logs for HTTP requests
        """
        return _NOOP_MESSAGE_REPOSITORY

    def _log_response(self, response: requests.Response) -> None:
        """
        Logs the HTTP response at DEBUG level using the message repository if it is available.

        Args:
            response (requests.Response): The HTTP response to log.
        """
        if self._message_repository:
            self._message_repository.log_message(
                Level.DEBUG,
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    self._NO_STREAM_NAME,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    # ----------------
    # ABSTR METHODS
    # ----------------

    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """The client id name to authenticate"""

    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """The client secret name to authenticate"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """The refresh token name to authenticate"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the expires_in field name"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns the request body to set on the refresh request"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is reached while trying to find and get
    a target key in a token refresh response (see `_find_and_get_value_from_response`).

    `_make_handled_request` swallows it when looking up the access token for secret
    masking; `_ensure_access_token_in_response` lets it propagate.
    """
Raised when the maximum recursion depth is reached while trying to find and get
the target key in a token refresh response (for example, during `_make_handled_request`).
37class AbstractOauth2Authenticator(AuthBase): 38 """ 39 Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator 40 is designed to generically perform the refresh flow without regard to how config fields are get/set by 41 delegating that behavior to the classes implementing the interface. 42 """ 43 44 _NO_STREAM_NAME = None 45 46 # Class-level lock to prevent concurrent token refresh across multiple authenticator instances. 47 # This is necessary because multiple streams may share the same OAuth credentials (refresh token) 48 # through the connector config. Without this lock, concurrent refresh attempts can cause race 49 # conditions where one stream successfully refreshes the token while others fail because the 50 # refresh token has been invalidated (especially for single-use refresh tokens). 51 _token_refresh_lock: threading.Lock = threading.Lock() 52 53 def __init__( 54 self, 55 refresh_token_error_status_codes: Tuple[int, ...] = (), 56 refresh_token_error_key: str = "", 57 refresh_token_error_values: Tuple[str, ...] = (), 58 ) -> None: 59 """ 60 If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, 61 then http errors with such params will be wrapped in AirbyteTracedException. 
62 """ 63 self._refresh_token_error_status_codes = refresh_token_error_status_codes 64 self._refresh_token_error_key = refresh_token_error_key 65 self._refresh_token_error_values = refresh_token_error_values 66 67 def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: 68 """Attach the HTTP headers required to authenticate on the HTTP request""" 69 request.headers.update(self.get_auth_header()) 70 return request 71 72 @property 73 def _is_access_token_flow(self) -> bool: 74 return self.get_token_refresh_endpoint() is None and self.access_token is not None 75 76 @property 77 def token_expiry_is_time_of_expiration(self) -> bool: 78 """ 79 Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid. 80 """ 81 82 return False 83 84 @property 85 def token_expiry_date_format(self) -> Optional[str]: 86 """ 87 Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires 88 """ 89 90 return None 91 92 def get_auth_header(self) -> Mapping[str, Any]: 93 """HTTP header to set on the requests""" 94 token = self.access_token if self._is_access_token_flow else self.get_access_token() 95 return {"Authorization": f"Bearer {token}"} 96 97 def get_access_token(self) -> str: 98 """ 99 Returns the access token. 100 101 This method uses double-checked locking to ensure thread-safe token refresh. 102 When multiple threads (streams) detect an expired token simultaneously, only one 103 will perform the refresh while others wait. After acquiring the lock, the token 104 expiry is re-checked to avoid redundant refresh attempts. 
105 """ 106 if self.token_has_expired(): 107 with self._token_refresh_lock: 108 # Double-check after acquiring lock - another thread may have already refreshed 109 if self.token_has_expired(): 110 self.refresh_and_set_access_token() 111 112 return self.access_token 113 114 def refresh_and_set_access_token(self) -> None: 115 """Force refresh the access token and update internal state. 116 117 This method refreshes the access token regardless of whether it has expired, 118 and updates the internal token and expiry date. Subclasses may override this 119 to handle additional state updates (e.g., persisting new refresh tokens). 120 """ 121 token, expires_in = self.refresh_access_token() 122 self.access_token = token 123 self.set_token_expiry_date(expires_in) 124 125 def token_has_expired(self) -> bool: 126 """Returns True if the token is expired""" 127 return ab_datetime_now() > self.get_token_expiry_date() 128 129 def build_refresh_request_body(self) -> Mapping[str, Any]: 130 """ 131 Returns the request body to set on the refresh request. 132 133 Override to define additional parameters. 134 135 Client credentials (client_id and client_secret) are excluded from the body when 136 refresh_request_headers contains an Authorization header (e.g., Basic auth). 137 This is required by OAuth providers like Gong that expect credentials ONLY in the 138 Authorization header and reject requests that include them in both places. 
139 """ 140 # Check if credentials are being sent via Authorization header 141 headers = self.get_refresh_request_headers() 142 credentials_in_header = headers and "Authorization" in headers 143 144 # Only include client credentials in body if not already in header 145 include_client_credentials = not credentials_in_header 146 147 payload: MutableMapping[str, Any] = { 148 self.get_grant_type_name(): self.get_grant_type(), 149 } 150 151 # Only include client credentials in body if configured to do so and not in header 152 if include_client_credentials: 153 payload[self.get_client_id_name()] = self.get_client_id() 154 payload[self.get_client_secret_name()] = self.get_client_secret() 155 156 payload[self.get_refresh_token_name()] = self.get_refresh_token() 157 158 if self.get_scopes(): 159 payload["scopes"] = self.get_scopes() 160 161 if self.get_refresh_request_body(): 162 for key, val in self.get_refresh_request_body().items(): 163 # We defer to existing oauth constructs over custom configured fields 164 if key not in payload: 165 payload[key] = val 166 167 return payload 168 169 def build_refresh_request_headers(self) -> Mapping[str, Any] | None: 170 """ 171 Returns the request headers to set on the refresh request 172 173 """ 174 headers = self.get_refresh_request_headers() 175 return headers if headers else None 176 177 def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]: 178 """ 179 Returns the refresh token and its expiration datetime 180 181 :return: a tuple of (access_token, token_lifespan) 182 """ 183 try: 184 response_json = self._make_handled_request() 185 except ( 186 requests.exceptions.ConnectionError, 187 requests.exceptions.ConnectTimeout, 188 requests.exceptions.ReadTimeout, 189 ) as e: 190 raise AirbyteTracedException( 191 message="OAuth access token refresh request failed due to a network error.", 192 internal_message=f"Network error during OAuth token refresh after retries were exhausted: {e}", 193 
failure_type=FailureType.transient_error, 194 ) from e 195 self._ensure_access_token_in_response(response_json) 196 197 return ( 198 self._extract_access_token(response_json), 199 self._extract_token_expiry_date(response_json), 200 ) 201 202 # ---------------- 203 # PRIVATE METHODS 204 # ---------------- 205 206 def _default_token_expiry_date(self) -> AirbyteDateTime: 207 """ 208 Returns the default token expiry date 209 """ 210 # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration 211 default_token_expiry_duration_hours = 1 # 1 hour 212 return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours) 213 214 def _wrap_refresh_token_exception( 215 self, exception: requests.exceptions.RequestException 216 ) -> bool: 217 """ 218 Wraps and handles exceptions that occur during the refresh token process. 219 220 This method checks if the provided exception is related to a refresh token error 221 by examining the response status code and specific error content. 222 223 Args: 224 exception (requests.exceptions.RequestException): The exception raised during the request. 225 226 Returns: 227 bool: True if the exception is related to a refresh token error, False otherwise. 228 """ 229 try: 230 if exception.response is not None: 231 exception_content = exception.response.json() 232 else: 233 return False 234 except JSONDecodeError: 235 return False 236 return ( 237 exception.response.status_code in self._refresh_token_error_status_codes 238 and exception_content.get(self._refresh_token_error_key) 239 in self._refresh_token_error_values 240 ) 241 242 @backoff.on_exception( 243 backoff.expo, 244 ( 245 DefaultBackoffException, 246 requests.exceptions.ConnectionError, 247 requests.exceptions.ConnectTimeout, 248 requests.exceptions.ReadTimeout, 249 ), 250 on_backoff=lambda details: logger.info( 251 f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." 
252 ), 253 max_time=300, 254 ) 255 def _make_handled_request(self) -> Any: 256 """ 257 Makes a handled HTTP request to refresh an OAuth token. 258 259 This method sends a POST request to the token refresh endpoint with the necessary 260 headers and body to obtain a new access token. It handles various exceptions that 261 may occur during the request and logs the response for troubleshooting purposes. 262 263 Returns: 264 Mapping[str, Any]: The JSON response from the token refresh endpoint. 265 266 Raises: 267 DefaultBackoffException: If the response status code is 429 (Too Many Requests) 268 or any 5xx server error. 269 AirbyteTracedException: If the refresh token is invalid or expired, prompting 270 re-authentication. 271 Exception: For any other exceptions that occur during the request. 272 """ 273 try: 274 response = requests.request( 275 method="POST", 276 url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected. 277 data=self.build_refresh_request_body(), 278 headers=self.build_refresh_request_headers(), 279 ) 280 281 if not response.ok: 282 # log the response even if the request failed for troubleshooting purposes 283 self._log_response(response) 284 response.raise_for_status() 285 286 response_json = response.json() 287 288 try: 289 # extract the access token and add to secrets to avoid logging the raw value 290 access_key = self._extract_access_token(response_json) 291 if access_key: 292 add_to_secrets(access_key) 293 except ResponseKeysMaxRecurtionReached as e: 294 # could not find the access token in the response, so do nothing 295 pass 296 297 self._log_response(response) 298 299 return response_json 300 except requests.exceptions.RequestException as e: 301 if e.response is not None: 302 if e.response.status_code == 429 or e.response.status_code >= 500: 303 raise DefaultBackoffException( 304 request=e.response.request, 305 response=e.response, 306 failure_type=FailureType.transient_error, 307 ) 
308 if self._wrap_refresh_token_exception(e): 309 message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings." 310 raise AirbyteTracedException( 311 internal_message=message, message=message, failure_type=FailureType.config_error 312 ) 313 raise 314 except Exception as e: 315 raise AirbyteTracedException( 316 message="OAuth access token refresh request failed.", 317 internal_message=f"Unexpected error during OAuth token refresh: {e}", 318 failure_type=FailureType.system_error, 319 ) from e 320 321 def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None: 322 """ 323 Ensures that the access token is present in the response data. 324 325 This method attempts to extract the access token from the provided response data. 326 If the access token is not found, it raises an exception indicating that the token 327 refresh API response was missing the access token. 328 329 Args: 330 response_data (Mapping[str, Any]): The response data from which to extract the access token. 331 332 Raises: 333 Exception: If the access token is not found in the response data. 334 ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token. 335 """ 336 try: 337 access_key = self._extract_access_token(response_data) 338 if not access_key: 339 raise Exception( 340 f"Token refresh API response was missing access token {self.get_access_token_name()}" 341 ) 342 except ResponseKeysMaxRecurtionReached as e: 343 raise e 344 345 def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime: 346 """ 347 Parse a string or integer token expiration date into a datetime object 348 349 :return: expiration datetime 350 """ 351 if self.token_expiry_is_time_of_expiration: 352 if not self.token_expiry_date_format: 353 raise ValueError( 354 f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required." 
355 ) 356 try: 357 return ab_datetime_parse(str(value)) 358 except ValueError as e: 359 raise ValueError(f"Invalid token expiry date format: {e}") 360 else: 361 try: 362 # Only accept numeric values (as int/float/string) when no format specified 363 seconds = int(float(str(value))) 364 return ab_datetime_now() + timedelta(seconds=seconds) 365 except (ValueError, TypeError): 366 raise ValueError( 367 f"Invalid expires_in value: {value}. Expected number of seconds when no format specified." 368 ) 369 370 def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any: 371 """ 372 Extracts the access token from the given response data. 373 374 Args: 375 response_data (Mapping[str, Any]): The response data from which to extract the access token. 376 377 Returns: 378 str: The extracted access token. 379 """ 380 return self._find_and_get_value_from_response(response_data, self.get_access_token_name()) 381 382 def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any: 383 """ 384 Extracts the refresh token from the given response data. 385 386 Args: 387 response_data (Mapping[str, Any]): The response data from which to extract the refresh token. 388 389 Returns: 390 str: The extracted refresh token. 391 """ 392 return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name()) 393 394 def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime: 395 """ 396 Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data. 397 398 If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date. 399 400 Args: 401 response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date. 402 403 Returns: 404 The extracted token_expiry_date or None if not found. 
405 """ 406 expires_in = self._find_and_get_value_from_response( 407 response_data, self.get_expires_in_name() 408 ) 409 if expires_in is not None: 410 return self._parse_token_expiration_date(expires_in) 411 412 # expires_in is None 413 existing_expiry_date = self.get_token_expiry_date() 414 if existing_expiry_date and not self.token_has_expired(): 415 return existing_expiry_date 416 417 return self._default_token_expiry_date() 418 419 def _find_and_get_value_from_response( 420 self, 421 response_data: Mapping[str, Any], 422 key_name: str, 423 max_depth: int = 5, 424 current_depth: int = 0, 425 ) -> Any: 426 """ 427 Recursively searches for a specified key in a nested dictionary or list and returns its value if found. 428 429 Args: 430 response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list. 431 key_name (str): The key to search for in the response data. 432 max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5. 433 current_depth (int, optional): The current depth of the recursion. Defaults to 0. 434 435 Returns: 436 Any: The value associated with the specified key if found, otherwise None. 437 438 Raises: 439 AirbyteTracedException: If the maximum recursion depth is reached without finding the key. 440 """ 441 if current_depth > max_depth: 442 # this is needed to avoid an inf loop, possible with a very deep nesting observed. 443 message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response." 
444 raise ResponseKeysMaxRecurtionReached( 445 internal_message=message, message=message, failure_type=FailureType.config_error 446 ) 447 448 if isinstance(response_data, dict): 449 # get from the root level 450 if key_name in response_data: 451 return response_data[key_name] 452 453 # get from the nested object 454 for _, value in response_data.items(): 455 result = self._find_and_get_value_from_response( 456 value, key_name, max_depth, current_depth + 1 457 ) 458 if result is not None: 459 return result 460 461 # get from the nested array object 462 elif isinstance(response_data, list): 463 for item in response_data: 464 result = self._find_and_get_value_from_response( 465 item, key_name, max_depth, current_depth + 1 466 ) 467 if result is not None: 468 return result 469 470 return None 471 472 @property 473 def _message_repository(self) -> Optional[MessageRepository]: 474 """ 475 The implementation can define a message_repository if it wants debugging logs for HTTP requests 476 """ 477 return _NOOP_MESSAGE_REPOSITORY 478 479 def _log_response(self, response: requests.Response) -> None: 480 """ 481 Logs the HTTP response using the message repository if it is available. 482 483 Args: 484 response (requests.Response): The HTTP response to log. 
485 """ 486 if self._message_repository: 487 self._message_repository.log_message( 488 Level.DEBUG, 489 lambda: format_http_message( 490 response, 491 "Refresh token", 492 "Obtains access token", 493 self._NO_STREAM_NAME, 494 is_auxiliary=True, 495 type="AUTH", 496 ), 497 ) 498 499 # ---------------- 500 # ABSTR METHODS 501 # ---------------- 502 503 @abstractmethod 504 def get_token_refresh_endpoint(self) -> Optional[str]: 505 """Returns the endpoint to refresh the access token""" 506 507 @abstractmethod 508 def get_client_id_name(self) -> str: 509 """The client id name to authenticate""" 510 511 @abstractmethod 512 def get_client_id(self) -> str: 513 """The client id to authenticate""" 514 515 @abstractmethod 516 def get_client_secret_name(self) -> str: 517 """The client secret name to authenticate""" 518 519 @abstractmethod 520 def get_client_secret(self) -> str: 521 """The client secret to authenticate""" 522 523 @abstractmethod 524 def get_refresh_token_name(self) -> str: 525 """The refresh token name to authenticate""" 526 527 @abstractmethod 528 def get_refresh_token(self) -> Optional[str]: 529 """The token used to refresh the access token when it expires""" 530 531 @abstractmethod 532 def get_scopes(self) -> List[str]: 533 """List of requested scopes""" 534 535 @abstractmethod 536 def get_token_expiry_date(self) -> AirbyteDateTime: 537 """Expiration date of the access token""" 538 539 @abstractmethod 540 def set_token_expiry_date(self, value: AirbyteDateTime) -> None: 541 """Setter for access token expiration date""" 542 543 @abstractmethod 544 def get_access_token_name(self) -> str: 545 """Field to extract access token from in the response""" 546 547 @abstractmethod 548 def get_expires_in_name(self) -> str: 549 """Returns the expires_in field name""" 550 551 @abstractmethod 552 def get_refresh_request_body(self) -> Mapping[str, Any]: 553 """Returns the request body to set on the refresh request""" 554 555 @abstractmethod 556 def 
get_refresh_request_headers(self) -> Mapping[str, Any]: 557 """Returns the request headers to set on the refresh request""" 558 559 @abstractmethod 560 def get_grant_type(self) -> str: 561 """Returns grant_type specified for requesting access_token""" 562 563 @abstractmethod 564 def get_grant_type_name(self) -> str: 565 """Returns grant_type specified name for requesting access_token""" 566 567 @property 568 @abstractmethod 569 def access_token(self) -> str: 570 """Returns the access token""" 571 572 @access_token.setter 573 @abstractmethod 574 def access_token(self, value: str) -> str: 575 """Setter for the access token"""
Abstract class for OAuth authenticators that implements the OAuth token refresh flow. The authenticator is designed to perform the refresh flow generically, without regard to how config fields are read or written, by delegating that behavior to the classes implementing the interface.
def __init__(
    self,
    refresh_token_error_status_codes: Tuple[int, ...] = (),
    refresh_token_error_key: str = "",
    refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
    """
    Store the parameters used to recognise refresh-token errors.

    When refresh_token_error_status_codes, refresh_token_error_key and
    refresh_token_error_values are all provided, HTTP errors matching them are
    wrapped in AirbyteTracedException.

    :param refresh_token_error_status_codes: HTTP status codes that may indicate a bad refresh token
    :param refresh_token_error_key: response key whose value is inspected for refresh-token errors
    :param refresh_token_error_values: values of that key that indicate a refresh-token error
    """
    (
        self._refresh_token_error_status_codes,
        self._refresh_token_error_key,
        self._refresh_token_error_values,
    ) = (
        refresh_token_error_status_codes,
        refresh_token_error_key,
        refresh_token_error_values,
    )
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, then HTTP errors matching these parameters will be wrapped in AirbyteTracedException.
@property
def token_expiry_is_time_of_expiration(self) -> bool:
    """
    Whether the expiry value in the refresh response is an absolute expiration date
    rather than a number of seconds the token stays valid. Defaults to False.
    """
    is_absolute_date = False
    return is_absolute_date
Indicates whether the token expiry value is the date until which the token will be valid, rather than the amount of time it will be valid.
@property
def token_expiry_date_format(self) -> Optional[str]:
    """
    Format string for the token expiry value. Set this when the refresh response returns
    the expiration as a datetime instead of the number of seconds until it expires.
    Defaults to None.
    """
    no_format: Optional[str] = None
    return no_format
Format of the datetime; set it if expires_in is returned as the expiration datetime instead of the number of seconds until the token expires.
def get_auth_header(self) -> Mapping[str, Any]:
    """
    Build the HTTP header that authenticates outgoing requests.

    In the access-token flow the stored token is used as-is; otherwise
    get_access_token() is called, which may refresh an expired token.
    """
    if self._is_access_token_flow:
        bearer = self.access_token
    else:
        bearer = self.get_access_token()
    return {"Authorization": f"Bearer {bearer}"}
HTTP header to set on the requests
def get_access_token(self) -> str:
    """
    Return the current access token, refreshing it first if it has expired.

    Expiry is checked once without the lock (fast path). If the token looks expired,
    the shared _token_refresh_lock is taken and expiry is checked again before
    refreshing. A thread that was waiting on the lock therefore does not repeat a
    refresh that another thread has just completed.
    """
    if not self.token_has_expired():
        return self.access_token

    with self._token_refresh_lock:
        # Re-check under the lock: another thread may have refreshed while we waited.
        if self.token_has_expired():
            self.refresh_and_set_access_token()

    return self.access_token
Returns the access token.
This method uses double-checked locking to ensure thread-safe token refresh. When multiple threads (streams) detect an expired token simultaneously, only one will perform the refresh while others wait. After acquiring the lock, the token expiry is re-checked to avoid redundant refresh attempts.
def refresh_and_set_access_token(self) -> None:
    """Unconditionally refresh the access token and store it with its new expiry date.

    No expiry check is made here. The new token is assigned before the expiry
    date is stored. Subclasses may override this to persist extra state, such as
    a rotated refresh token.
    """
    new_token, new_expiry = self.refresh_access_token()
    self.access_token = new_token
    self.set_token_expiry_date(new_expiry)
Force refresh the access token and update internal state.
This method refreshes the access token regardless of whether it has expired, and updates the internal token and expiry date. Subclasses may override this to handle additional state updates (e.g., persisting new refresh tokens).
def token_has_expired(self) -> bool:
    """Return True when the current time is strictly later than the stored token expiry date."""
    current_time = ab_datetime_now()
    return current_time > self.get_token_expiry_date()
Returns True if the token is expired
def build_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Returns the request body to set on the refresh request.

    Override to define additional parameters.

    The body always contains the grant type and the refresh token. Client credentials
    (client_id and client_secret) are added only when refresh_request_headers does NOT
    contain an Authorization header (e.g., Basic auth). Header names are matched
    case-insensitively, as HTTP requires. Some OAuth providers, such as Gong, expect
    credentials ONLY in the Authorization header and reject requests that include them
    in both places.

    Scopes, when present, are sent under "scopes". Fields from get_refresh_request_body()
    are appended without overriding any field already set above.
    """
    headers = self.get_refresh_request_headers() or {}
    # HTTP field names are case-insensitive, so "authorization" counts as well.
    credentials_in_header = any(
        isinstance(name, str) and name.lower() == "authorization" for name in headers
    )

    payload: MutableMapping[str, Any] = {
        self.get_grant_type_name(): self.get_grant_type(),
    }

    # Only include client credentials in the body when they are not already sent in a header.
    if not credentials_in_header:
        payload[self.get_client_id_name()] = self.get_client_id()
        payload[self.get_client_secret_name()] = self.get_client_secret()

    payload[self.get_refresh_token_name()] = self.get_refresh_token()

    scopes = self.get_scopes()
    if scopes:
        payload["scopes"] = scopes

    extra_body = self.get_refresh_request_body()
    if extra_body:
        for key, val in extra_body.items():
            # We defer to existing oauth constructs over custom configured fields.
            payload.setdefault(key, val)

    return payload
Returns the request body to set on the refresh request.
Override to define additional parameters.
Client credentials (client_id and client_secret) are excluded from the body when refresh_request_headers contains an Authorization header (e.g., Basic auth). This is required by OAuth providers like Gong that expect credentials ONLY in the Authorization header and reject requests that include them in both places.
169 def build_refresh_request_headers(self) -> Mapping[str, Any] | None: 170 """ 171 Returns the request headers to set on the refresh request 172 173 """ 174 headers = self.get_refresh_request_headers() 175 return headers if headers else None
Returns the request headers to set on the refresh request
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
    """
    Request a new access token from the token refresh endpoint.

    :return: a tuple of (access_token, token_expiry_date), where token_expiry_date is
        the datetime at which the new access token expires
    :raises AirbyteTracedException: with failure type transient_error when the refresh
        request fails with a connection error or timeout that propagates out of
        `_make_handled_request`
    :raises Exception: if the refresh response does not contain the access token
    """
    try:
        response_json = self._make_handled_request()
    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.ConnectTimeout,
        requests.exceptions.ReadTimeout,
    ) as e:
        # Network-level failures are reported as transient rather than config/system errors.
        raise AirbyteTracedException(
            message="OAuth access token refresh request failed due to a network error.",
            internal_message=f"Network error during OAuth token refresh after retries were exhausted: {e}",
            failure_type=FailureType.transient_error,
        ) from e
    self._ensure_access_token_in_response(response_json)

    return (
        self._extract_access_token(response_json),
        self._extract_token_expiry_date(response_json),
    )
Returns the access token and its expiration datetime
Returns
a tuple of (access_token, token_expiry_date)
@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
    """Return the URL used to refresh the access token, or None if there is no refresh endpoint."""
Returns the endpoint to refresh the access token
@abstractmethod
def get_client_id_name(self) -> str:
    """Return the field name under which the client ID is sent in the refresh request."""
The name of the client ID field used to authenticate
@abstractmethod
def get_client_secret_name(self) -> str:
    """Return the field name under which the client secret is sent in the refresh request."""
The name of the client secret field used to authenticate
@abstractmethod
def get_client_secret(self) -> str:
    """Return the client secret used to authenticate the refresh request."""
The client secret to authenticate
@abstractmethod
def get_refresh_token_name(self) -> str:
    """Return the field name used for the refresh token in the refresh request and response."""
The name of the refresh token field used to authenticate
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
    """Return the refresh token used to obtain a new access token, or None if there is none."""
The token used to refresh the access token when it expires
@abstractmethod
def get_token_expiry_date(self) -> AirbyteDateTime:
    """Return the date and time at which the current access token expires."""
Expiration date of the access token
@abstractmethod
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
    """Store `value` as the date and time at which the current access token expires."""
Setter for access token expiration date
@abstractmethod
def get_access_token_name(self) -> str:
    """Return the name of the refresh-response field that holds the access token."""
Field to extract access token from in the response
@abstractmethod
def get_expires_in_name(self) -> str:
    """Return the name of the refresh-response field that holds the token expiry (e.g. `expires_in`)."""
Returns the expires_in field name
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
    """Return extra fields for the refresh request body; fields the authenticator already sets take precedence."""
Returns the request body to set on the refresh request
@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
    """Return the headers to send with the refresh request; an Authorization header keeps client credentials out of the body."""
Returns the request headers to set on the refresh request
@abstractmethod
def get_grant_type(self) -> str:
    """Return the grant_type value sent when requesting an access token."""
Returns grant_type specified for requesting access_token