airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
import threading
from abc import abstractmethod
from datetime import timedelta
from json import JSONDecodeError
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

import backoff
import requests
from requests.auth import AuthBase

from airbyte_cdk.models import FailureType, Level
from airbyte_cdk.sources.http_logger import format_http_message
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse

from ..exceptions import DefaultBackoffException

logger = logging.getLogger("airbyte")
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()


class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is exceeded while searching a token refresh
    response for a target key (see `_find_and_get_value_from_response`).
    """


class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract class for OAuth authenticators that implements the OAuth token refresh flow. The authenticator
    is designed to generically perform the refresh flow without regard to how config fields are get/set by
    delegating that behavior to the classes implementing the interface.
    """

    _NO_STREAM_NAME = None

    # Class-level lock to prevent concurrent token refresh across multiple authenticator instances.
    # This is necessary because multiple streams may share the same OAuth credentials (refresh token)
    # through the connector config. Without this lock, concurrent refresh attempts can cause race
    # conditions where one stream successfully refreshes the token while others fail because the
    # refresh token has been invalidated (especially for single-use refresh tokens).
    _token_refresh_lock: threading.Lock = threading.Lock()

    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then http errors with such params will be wrapped in AirbyteTracedException.
        """
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the HTTP headers required to authenticate on the HTTP request"""
        request.headers.update(self.get_auth_header())
        return request

    @property
    def _is_access_token_flow(self) -> bool:
        """True when no refresh endpoint is configured but an access token is available."""
        return self.get_token_refresh_endpoint() is None and self.access_token is not None

    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
        """

        return False

    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the datetime; set it if expires_in is returned as the expiration datetime instead of seconds until it expires
        """

        return None

    def get_auth_header(self) -> Mapping[str, Any]:
        """HTTP header to set on the requests"""
        token = self.access_token if self._is_access_token_flow else self.get_access_token()
        return {"Authorization": f"Bearer {token}"}

    def get_access_token(self) -> str:
        """
        Returns the access token.

        This method uses double-checked locking to ensure thread-safe token refresh.
        When multiple threads (streams) detect an expired token simultaneously, only one
        will perform the refresh while others wait. After acquiring the lock, the token
        expiry is re-checked to avoid redundant refresh attempts.
        """
        if self.token_has_expired():
            with self._token_refresh_lock:
                # Double-check after acquiring lock - another thread may have already refreshed
                if self.token_has_expired():
                    self.refresh_and_set_access_token()

        return self.access_token

    def refresh_and_set_access_token(self) -> None:
        """Force refresh the access token and update internal state.

        This method refreshes the access token regardless of whether it has expired,
        and updates the internal token and expiry date. Subclasses may override this
        to handle additional state updates (e.g., persisting new refresh tokens).
        """
        token, expires_in = self.refresh_access_token()
        self.access_token = token
        self.set_token_expiry_date(expires_in)

    def token_has_expired(self) -> bool:
        """Returns True if the token is expired"""
        return ab_datetime_now() > self.get_token_expiry_date()

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request.

        Override to define additional parameters.

        Client credentials (client_id and client_secret) are excluded from the body when
        refresh_request_headers contains an Authorization header (e.g., Basic auth).
        This is required by OAuth providers like Gong that expect credentials ONLY in the
        Authorization header and reject requests that include them in both places.
        """
        # Check if credentials are being sent via Authorization header
        headers = self.get_refresh_request_headers()
        credentials_in_header = bool(headers) and "Authorization" in headers

        payload: MutableMapping[str, Any] = {
            self.get_grant_type_name(): self.get_grant_type(),
        }

        # Only include client credentials in body if not already in header
        if not credentials_in_header:
            payload[self.get_client_id_name()] = self.get_client_id()
            payload[self.get_client_secret_name()] = self.get_client_secret()

        payload[self.get_refresh_token_name()] = self.get_refresh_token()

        # Each getter is evaluated once and the result reused for both the check and the assignment
        scopes = self.get_scopes()
        if scopes:
            payload["scopes"] = scopes

        extra_body = self.get_refresh_request_body()
        if extra_body:
            for key, val in extra_body.items():
                # We defer to existing oauth constructs over custom configured fields
                if key not in payload:
                    payload[key] = val

        return payload

    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
        """
        Returns the request headers to set on the refresh request, or None when none are configured
        """
        headers = self.get_refresh_request_headers()
        return headers if headers else None

    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
        """
        Requests a new access token and returns it together with its expiration datetime

        :return: a tuple of (access_token, token_expiry_date)
        """
        response_json = self._make_handled_request()
        self._ensure_access_token_in_response(response_json)

        return (
            self._extract_access_token(response_json),
            self._extract_token_expiry_date(response_json),
        )

    # ----------------
    # PRIVATE METHODS
    # ----------------

    def _default_token_expiry_date(self) -> AirbyteDateTime:
        """
        Returns the default token expiry date
        """
        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
        default_token_expiry_duration_hours = 1  # 1 hour
        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)

    def _wrap_refresh_token_exception(
        self, exception: requests.exceptions.RequestException
    ) -> bool:
        """
        Tells whether a request exception matches the configured refresh token error.

        The exception matches when its response status code is one of
        `refresh_token_error_status_codes` and the JSON body holds one of
        `refresh_token_error_values` under `refresh_token_error_key`.

        Args:
            exception (requests.exceptions.RequestException): The exception raised during the request.

        Returns:
            bool: True if the exception is related to a refresh token error, False otherwise
                (including when there is no response or the body is not a JSON object).
        """
        response = exception.response
        if response is None:
            return False

        try:
            exception_content = response.json()
        except JSONDecodeError:
            return False

        # A JSON array or scalar body has no `.get`; treat it as "not a refresh token error"
        # instead of raising AttributeError from inside the error handler.
        if not isinstance(exception_content, dict):
            return False

        return (
            response.status_code in self._refresh_token_error_status_codes
            and exception_content.get(self._refresh_token_error_key)
            in self._refresh_token_error_values
        )

    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the necessary
        headers and body to obtain a new access token. It handles various exceptions that
        may occur during the request and logs the response for troubleshooting purposes.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication.
            Exception: For any other exceptions that occur during the request.
        """
        try:
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )

            if not response.ok:
                # log the response even if the request failed for troubleshooting purposes
                self._log_response(response)
                response.raise_for_status()

            response_json = response.json()

            try:
                # extract the access token and add to secrets to avoid logging the raw value
                access_key = self._extract_access_token(response_json)
                if access_key:
                    add_to_secrets(access_key)
            except ResponseKeysMaxRecurtionReached:
                # could not find the access token in the response, so do nothing here;
                # `_ensure_access_token_in_response` reports the problem to the caller.
                pass

            self._log_response(response)

            return response_json
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(
                        request=e.response.request,
                        response=e.response,
                        failure_type=FailureType.transient_error,
                    )
                if self._wrap_refresh_token_exception(e):
                    message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                    raise AirbyteTracedException(
                        internal_message=message, message=message, failure_type=FailureType.config_error
                    )
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e

    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
        """
        Ensures that the access token is present in the response data.

        This method attempts to extract the access token from the provided response data.
        If the access token is not found, it raises an exception indicating that the token
        refresh API response was missing the access token.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Raises:
            Exception: If the access token is not found in the response data.
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
        """
        # ResponseKeysMaxRecurtionReached propagates on its own; no re-raise wrapper is needed.
        access_key = self._extract_access_token(response_data)
        if not access_key:
            raise Exception(
                f"Token refresh API response was missing access token {self.get_access_token_name()}"
            )

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """
        Parse a string or integer token expiration date into a datetime object

        When `token_expiry_is_time_of_expiration` is set, the value is parsed as a datetime
        (and `token_expiry_date_format` must be set); otherwise it is read as a number of
        seconds counted from now.

        :return: expiration datetime
        :raises ValueError: if the value cannot be parsed, or the format is required but missing
        """
        if self.token_expiry_is_time_of_expiration:
            if not self.token_expiry_date_format:
                raise ValueError(
                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
                )
            try:
                return ab_datetime_parse(str(value))
            except ValueError as e:
                raise ValueError(f"Invalid token expiry date format: {e}") from e
        else:
            try:
                # Only accept numeric values (as int/float/string) when no format specified
                seconds = int(float(str(value)))
                return ab_datetime_now() + timedelta(seconds=seconds)
            except (ValueError, TypeError, OverflowError) as e:
                # OverflowError covers infinite or out-of-range values (e.g. "Infinity"), so
                # callers get the same ValueError as for any other malformed expires_in.
                raise ValueError(
                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
                ) from e

    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the access token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Returns:
            The value found under `get_access_token_name()`, or None if the key is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())

    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the refresh token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.

        Returns:
            The value found under `get_refresh_token_name()`, or None if the key is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
        """
        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

        If the token_expiry_date is not found, it will return an existing token expiry date if set
        and not yet expired, or a default token expiry date.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

        Returns:
            The parsed token_expiry_date, the still-valid existing expiry date, or the default expiry date.
        """
        expires_in = self._find_and_get_value_from_response(
            response_data, self.get_expires_in_name()
        )
        if expires_in is not None:
            return self._parse_token_expiration_date(expires_in)

        # expires_in is None
        existing_expiry_date = self.get_token_expiry_date()
        if existing_expiry_date and not self.token_has_expired():
            return existing_expiry_date

        return self._default_token_expiry_date()

    def _find_and_get_value_from_response(
        self,
        response_data: Mapping[str, Any],
        key_name: str,
        max_depth: int = 5,
        current_depth: int = 0,
    ) -> Any:
        """
        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.

        Args:
            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
            key_name (str): The key to search for in the response data.
            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
            current_depth (int, optional): The current depth of the recursion. Defaults to 0.

        Returns:
            Any: The value associated with the specified key if found, otherwise None.

        Raises:
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached without finding the key.
        """
        if current_depth > max_depth:
            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
            raise ResponseKeysMaxRecurtionReached(
                internal_message=message, message=message, failure_type=FailureType.config_error
            )

        if isinstance(response_data, dict):
            # get from the root level
            if key_name in response_data:
                return response_data[key_name]

            # get from the nested object
            for value in response_data.values():
                result = self._find_and_get_value_from_response(
                    value, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        # get from the nested array object
        elif isinstance(response_data, list):
            for item in response_data:
                result = self._find_and_get_value_from_response(
                    item, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        return None

    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        The implementation can define a message_repository if it wants debugging logs for HTTP requests
        """
        return _NOOP_MESSAGE_REPOSITORY

    def _log_response(self, response: requests.Response) -> None:
        """
        Logs the HTTP response using the message repository if it is available.

        Args:
            response (requests.Response): The HTTP response to log.
        """
        if self._message_repository:
            self._message_repository.log_message(
                Level.DEBUG,
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    self._NO_STREAM_NAME,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    # ----------------
    # ABSTR METHODS
    # ----------------

    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """The client id name to authenticate"""

    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """The client secret name to authenticate"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """The refresh token name to authenticate"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the expires_in field name"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns the request body to set on the refresh request"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is exceeded while searching a token refresh
    response for a target key.

    The recursive lookup that raises it is used, among other places, while handling the
    response in `_make_handled_request`.
    """
Raised when the maximum recursion depth is exceeded while trying to
find and get the target key in the response handled by `_make_handled_request`.
37class AbstractOauth2Authenticator(AuthBase): 38 """ 39 Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator 40 is designed to generically perform the refresh flow without regard to how config fields are get/set by 41 delegating that behavior to the classes implementing the interface. 42 """ 43 44 _NO_STREAM_NAME = None 45 46 # Class-level lock to prevent concurrent token refresh across multiple authenticator instances. 47 # This is necessary because multiple streams may share the same OAuth credentials (refresh token) 48 # through the connector config. Without this lock, concurrent refresh attempts can cause race 49 # conditions where one stream successfully refreshes the token while others fail because the 50 # refresh token has been invalidated (especially for single-use refresh tokens). 51 _token_refresh_lock: threading.Lock = threading.Lock() 52 53 def __init__( 54 self, 55 refresh_token_error_status_codes: Tuple[int, ...] = (), 56 refresh_token_error_key: str = "", 57 refresh_token_error_values: Tuple[str, ...] = (), 58 ) -> None: 59 """ 60 If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, 61 then http errors with such params will be wrapped in AirbyteTracedException. 
62 """ 63 self._refresh_token_error_status_codes = refresh_token_error_status_codes 64 self._refresh_token_error_key = refresh_token_error_key 65 self._refresh_token_error_values = refresh_token_error_values 66 67 def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: 68 """Attach the HTTP headers required to authenticate on the HTTP request""" 69 request.headers.update(self.get_auth_header()) 70 return request 71 72 @property 73 def _is_access_token_flow(self) -> bool: 74 return self.get_token_refresh_endpoint() is None and self.access_token is not None 75 76 @property 77 def token_expiry_is_time_of_expiration(self) -> bool: 78 """ 79 Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid. 80 """ 81 82 return False 83 84 @property 85 def token_expiry_date_format(self) -> Optional[str]: 86 """ 87 Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires 88 """ 89 90 return None 91 92 def get_auth_header(self) -> Mapping[str, Any]: 93 """HTTP header to set on the requests""" 94 token = self.access_token if self._is_access_token_flow else self.get_access_token() 95 return {"Authorization": f"Bearer {token}"} 96 97 def get_access_token(self) -> str: 98 """ 99 Returns the access token. 100 101 This method uses double-checked locking to ensure thread-safe token refresh. 102 When multiple threads (streams) detect an expired token simultaneously, only one 103 will perform the refresh while others wait. After acquiring the lock, the token 104 expiry is re-checked to avoid redundant refresh attempts. 
105 """ 106 if self.token_has_expired(): 107 with self._token_refresh_lock: 108 # Double-check after acquiring lock - another thread may have already refreshed 109 if self.token_has_expired(): 110 self.refresh_and_set_access_token() 111 112 return self.access_token 113 114 def refresh_and_set_access_token(self) -> None: 115 """Force refresh the access token and update internal state. 116 117 This method refreshes the access token regardless of whether it has expired, 118 and updates the internal token and expiry date. Subclasses may override this 119 to handle additional state updates (e.g., persisting new refresh tokens). 120 """ 121 token, expires_in = self.refresh_access_token() 122 self.access_token = token 123 self.set_token_expiry_date(expires_in) 124 125 def token_has_expired(self) -> bool: 126 """Returns True if the token is expired""" 127 return ab_datetime_now() > self.get_token_expiry_date() 128 129 def build_refresh_request_body(self) -> Mapping[str, Any]: 130 """ 131 Returns the request body to set on the refresh request. 132 133 Override to define additional parameters. 134 135 Client credentials (client_id and client_secret) are excluded from the body when 136 refresh_request_headers contains an Authorization header (e.g., Basic auth). 137 This is required by OAuth providers like Gong that expect credentials ONLY in the 138 Authorization header and reject requests that include them in both places. 
139 """ 140 # Check if credentials are being sent via Authorization header 141 headers = self.get_refresh_request_headers() 142 credentials_in_header = headers and "Authorization" in headers 143 144 # Only include client credentials in body if not already in header 145 include_client_credentials = not credentials_in_header 146 147 payload: MutableMapping[str, Any] = { 148 self.get_grant_type_name(): self.get_grant_type(), 149 } 150 151 # Only include client credentials in body if configured to do so and not in header 152 if include_client_credentials: 153 payload[self.get_client_id_name()] = self.get_client_id() 154 payload[self.get_client_secret_name()] = self.get_client_secret() 155 156 payload[self.get_refresh_token_name()] = self.get_refresh_token() 157 158 if self.get_scopes(): 159 payload["scopes"] = self.get_scopes() 160 161 if self.get_refresh_request_body(): 162 for key, val in self.get_refresh_request_body().items(): 163 # We defer to existing oauth constructs over custom configured fields 164 if key not in payload: 165 payload[key] = val 166 167 return payload 168 169 def build_refresh_request_headers(self) -> Mapping[str, Any] | None: 170 """ 171 Returns the request headers to set on the refresh request 172 173 """ 174 headers = self.get_refresh_request_headers() 175 return headers if headers else None 176 177 def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]: 178 """ 179 Returns the refresh token and its expiration datetime 180 181 :return: a tuple of (access_token, token_lifespan) 182 """ 183 response_json = self._make_handled_request() 184 self._ensure_access_token_in_response(response_json) 185 186 return ( 187 self._extract_access_token(response_json), 188 self._extract_token_expiry_date(response_json), 189 ) 190 191 # ---------------- 192 # PRIVATE METHODS 193 # ---------------- 194 195 def _default_token_expiry_date(self) -> AirbyteDateTime: 196 """ 197 Returns the default token expiry date 198 """ 199 # 1 hour was chosen as a middle 
ground to avoid unnecessary frequent refreshes and token expiration 200 default_token_expiry_duration_hours = 1 # 1 hour 201 return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours) 202 203 def _wrap_refresh_token_exception( 204 self, exception: requests.exceptions.RequestException 205 ) -> bool: 206 """ 207 Wraps and handles exceptions that occur during the refresh token process. 208 209 This method checks if the provided exception is related to a refresh token error 210 by examining the response status code and specific error content. 211 212 Args: 213 exception (requests.exceptions.RequestException): The exception raised during the request. 214 215 Returns: 216 bool: True if the exception is related to a refresh token error, False otherwise. 217 """ 218 try: 219 if exception.response is not None: 220 exception_content = exception.response.json() 221 else: 222 return False 223 except JSONDecodeError: 224 return False 225 return ( 226 exception.response.status_code in self._refresh_token_error_status_codes 227 and exception_content.get(self._refresh_token_error_key) 228 in self._refresh_token_error_values 229 ) 230 231 @backoff.on_exception( 232 backoff.expo, 233 DefaultBackoffException, 234 on_backoff=lambda details: logger.info( 235 f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." 236 ), 237 max_time=300, 238 ) 239 def _make_handled_request(self) -> Any: 240 """ 241 Makes a handled HTTP request to refresh an OAuth token. 242 243 This method sends a POST request to the token refresh endpoint with the necessary 244 headers and body to obtain a new access token. It handles various exceptions that 245 may occur during the request and logs the response for troubleshooting purposes. 246 247 Returns: 248 Mapping[str, Any]: The JSON response from the token refresh endpoint. 
249 250 Raises: 251 DefaultBackoffException: If the response status code is 429 (Too Many Requests) 252 or any 5xx server error. 253 AirbyteTracedException: If the refresh token is invalid or expired, prompting 254 re-authentication. 255 Exception: For any other exceptions that occur during the request. 256 """ 257 try: 258 response = requests.request( 259 method="POST", 260 url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected. 261 data=self.build_refresh_request_body(), 262 headers=self.build_refresh_request_headers(), 263 ) 264 265 if not response.ok: 266 # log the response even if the request failed for troubleshooting purposes 267 self._log_response(response) 268 response.raise_for_status() 269 270 response_json = response.json() 271 272 try: 273 # extract the access token and add to secrets to avoid logging the raw value 274 access_key = self._extract_access_token(response_json) 275 if access_key: 276 add_to_secrets(access_key) 277 except ResponseKeysMaxRecurtionReached as e: 278 # could not find the access token in the response, so do nothing 279 pass 280 281 self._log_response(response) 282 283 return response_json 284 except requests.exceptions.RequestException as e: 285 if e.response is not None: 286 if e.response.status_code == 429 or e.response.status_code >= 500: 287 raise DefaultBackoffException( 288 request=e.response.request, 289 response=e.response, 290 failure_type=FailureType.transient_error, 291 ) 292 if self._wrap_refresh_token_exception(e): 293 message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings." 
294 raise AirbyteTracedException( 295 internal_message=message, message=message, failure_type=FailureType.config_error 296 ) 297 raise 298 except Exception as e: 299 raise Exception(f"Error while refreshing access token: {e}") from e 300 301 def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None: 302 """ 303 Ensures that the access token is present in the response data. 304 305 This method attempts to extract the access token from the provided response data. 306 If the access token is not found, it raises an exception indicating that the token 307 refresh API response was missing the access token. 308 309 Args: 310 response_data (Mapping[str, Any]): The response data from which to extract the access token. 311 312 Raises: 313 Exception: If the access token is not found in the response data. 314 ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token. 315 """ 316 try: 317 access_key = self._extract_access_token(response_data) 318 if not access_key: 319 raise Exception( 320 f"Token refresh API response was missing access token {self.get_access_token_name()}" 321 ) 322 except ResponseKeysMaxRecurtionReached as e: 323 raise e 324 325 def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime: 326 """ 327 Parse a string or integer token expiration date into a datetime object 328 329 :return: expiration datetime 330 """ 331 if self.token_expiry_is_time_of_expiration: 332 if not self.token_expiry_date_format: 333 raise ValueError( 334 f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required." 
335 ) 336 try: 337 return ab_datetime_parse(str(value)) 338 except ValueError as e: 339 raise ValueError(f"Invalid token expiry date format: {e}") 340 else: 341 try: 342 # Only accept numeric values (as int/float/string) when no format specified 343 seconds = int(float(str(value))) 344 return ab_datetime_now() + timedelta(seconds=seconds) 345 except (ValueError, TypeError): 346 raise ValueError( 347 f"Invalid expires_in value: {value}. Expected number of seconds when no format specified." 348 ) 349 350 def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any: 351 """ 352 Extracts the access token from the given response data. 353 354 Args: 355 response_data (Mapping[str, Any]): The response data from which to extract the access token. 356 357 Returns: 358 str: The extracted access token. 359 """ 360 return self._find_and_get_value_from_response(response_data, self.get_access_token_name()) 361 362 def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any: 363 """ 364 Extracts the refresh token from the given response data. 365 366 Args: 367 response_data (Mapping[str, Any]): The response data from which to extract the refresh token. 368 369 Returns: 370 str: The extracted refresh token. 371 """ 372 return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name()) 373 374 def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime: 375 """ 376 Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data. 377 378 If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date. 379 380 Args: 381 response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date. 382 383 Returns: 384 The extracted token_expiry_date or None if not found. 
385 """ 386 expires_in = self._find_and_get_value_from_response( 387 response_data, self.get_expires_in_name() 388 ) 389 if expires_in is not None: 390 return self._parse_token_expiration_date(expires_in) 391 392 # expires_in is None 393 existing_expiry_date = self.get_token_expiry_date() 394 if existing_expiry_date and not self.token_has_expired(): 395 return existing_expiry_date 396 397 return self._default_token_expiry_date() 398 399 def _find_and_get_value_from_response( 400 self, 401 response_data: Mapping[str, Any], 402 key_name: str, 403 max_depth: int = 5, 404 current_depth: int = 0, 405 ) -> Any: 406 """ 407 Recursively searches for a specified key in a nested dictionary or list and returns its value if found. 408 409 Args: 410 response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list. 411 key_name (str): The key to search for in the response data. 412 max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5. 413 current_depth (int, optional): The current depth of the recursion. Defaults to 0. 414 415 Returns: 416 Any: The value associated with the specified key if found, otherwise None. 417 418 Raises: 419 AirbyteTracedException: If the maximum recursion depth is reached without finding the key. 420 """ 421 if current_depth > max_depth: 422 # this is needed to avoid an inf loop, possible with a very deep nesting observed. 423 message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response." 
424 raise ResponseKeysMaxRecurtionReached( 425 internal_message=message, message=message, failure_type=FailureType.config_error 426 ) 427 428 if isinstance(response_data, dict): 429 # get from the root level 430 if key_name in response_data: 431 return response_data[key_name] 432 433 # get from the nested object 434 for _, value in response_data.items(): 435 result = self._find_and_get_value_from_response( 436 value, key_name, max_depth, current_depth + 1 437 ) 438 if result is not None: 439 return result 440 441 # get from the nested array object 442 elif isinstance(response_data, list): 443 for item in response_data: 444 result = self._find_and_get_value_from_response( 445 item, key_name, max_depth, current_depth + 1 446 ) 447 if result is not None: 448 return result 449 450 return None 451 452 @property 453 def _message_repository(self) -> Optional[MessageRepository]: 454 """ 455 The implementation can define a message_repository if it wants debugging logs for HTTP requests 456 """ 457 return _NOOP_MESSAGE_REPOSITORY 458 459 def _log_response(self, response: requests.Response) -> None: 460 """ 461 Logs the HTTP response using the message repository if it is available. 462 463 Args: 464 response (requests.Response): The HTTP response to log. 
465 """ 466 if self._message_repository: 467 self._message_repository.log_message( 468 Level.DEBUG, 469 lambda: format_http_message( 470 response, 471 "Refresh token", 472 "Obtains access token", 473 self._NO_STREAM_NAME, 474 is_auxiliary=True, 475 type="AUTH", 476 ), 477 ) 478 479 # ---------------- 480 # ABSTR METHODS 481 # ---------------- 482 483 @abstractmethod 484 def get_token_refresh_endpoint(self) -> Optional[str]: 485 """Returns the endpoint to refresh the access token""" 486 487 @abstractmethod 488 def get_client_id_name(self) -> str: 489 """The client id name to authenticate""" 490 491 @abstractmethod 492 def get_client_id(self) -> str: 493 """The client id to authenticate""" 494 495 @abstractmethod 496 def get_client_secret_name(self) -> str: 497 """The client secret name to authenticate""" 498 499 @abstractmethod 500 def get_client_secret(self) -> str: 501 """The client secret to authenticate""" 502 503 @abstractmethod 504 def get_refresh_token_name(self) -> str: 505 """The refresh token name to authenticate""" 506 507 @abstractmethod 508 def get_refresh_token(self) -> Optional[str]: 509 """The token used to refresh the access token when it expires""" 510 511 @abstractmethod 512 def get_scopes(self) -> List[str]: 513 """List of requested scopes""" 514 515 @abstractmethod 516 def get_token_expiry_date(self) -> AirbyteDateTime: 517 """Expiration date of the access token""" 518 519 @abstractmethod 520 def set_token_expiry_date(self, value: AirbyteDateTime) -> None: 521 """Setter for access token expiration date""" 522 523 @abstractmethod 524 def get_access_token_name(self) -> str: 525 """Field to extract access token from in the response""" 526 527 @abstractmethod 528 def get_expires_in_name(self) -> str: 529 """Returns the expires_in field name""" 530 531 @abstractmethod 532 def get_refresh_request_body(self) -> Mapping[str, Any]: 533 """Returns the request body to set on the refresh request""" 534 535 @abstractmethod 536 def 
get_refresh_request_headers(self) -> Mapping[str, Any]: 537 """Returns the request headers to set on the refresh request""" 538 539 @abstractmethod 540 def get_grant_type(self) -> str: 541 """Returns grant_type specified for requesting access_token""" 542 543 @abstractmethod 544 def get_grant_type_name(self) -> str: 545 """Returns grant_type specified name for requesting access_token""" 546 547 @property 548 @abstractmethod 549 def access_token(self) -> str: 550 """Returns the access token""" 551 552 @access_token.setter 553 @abstractmethod 554 def access_token(self, value: str) -> str: 555 """Setter for the access token"""
Abstract base class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator is designed to perform the refresh flow generically, without regard to how config fields are read or written, by delegating that behavior to the classes implementing the interface.
53 def __init__( 54 self, 55 refresh_token_error_status_codes: Tuple[int, ...] = (), 56 refresh_token_error_key: str = "", 57 refresh_token_error_values: Tuple[str, ...] = (), 58 ) -> None: 59 """ 60 If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, 61 then http errors with such params will be wrapped in AirbyteTracedException. 62 """ 63 self._refresh_token_error_status_codes = refresh_token_error_status_codes 64 self._refresh_token_error_key = refresh_token_error_key 65 self._refresh_token_error_values = refresh_token_error_values
If refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are all set, then HTTP errors matching those parameters will be wrapped in AirbyteTracedException.
@property
def token_expiry_is_time_of_expiration(self) -> bool:
    """
    Whether the expiry value is an expiration date rather than a lifetime.

    True means the token expiry field holds the date until which the token is
    valid; False means it holds the amount of time the token stays valid.
    This base implementation always answers False.
    """
    return False
Indicates that the token expiry value is the date until which the token is valid, not the amount of time for which it is valid.
@property
def token_expiry_date_format(self) -> Optional[str]:
    """
    Format of the expiration datetime, if one applies.

    Set it when expires_in is returned as the expiration datetime instead of
    the number of seconds until expiry. This base implementation returns None.
    """
    return None
Format of the datetime; set it if expires_in is returned as the expiration datetime instead of the number of seconds until it expires.
def get_auth_header(self) -> Mapping[str, Any]:
    """
    Build the Authorization header to set on outgoing requests.

    In the access-token flow the stored access_token is used directly;
    otherwise the token comes from get_access_token().
    """
    if self._is_access_token_flow:
        bearer = self.access_token
    else:
        bearer = self.get_access_token()
    return {"Authorization": f"Bearer {bearer}"}
HTTP header to set on the requests
def get_access_token(self) -> str:
    """
    Return the current access token, refreshing it first when it has expired.

    Refreshing uses double-checked locking: the expiry is tested once without
    the lock and, if the token looks expired, again after acquiring
    _token_refresh_lock. Only a caller that still sees an expired token after
    acquiring the lock performs the refresh; the others reuse its result.
    """
    if not self.token_has_expired():
        return self.access_token

    with self._token_refresh_lock:
        # Another thread may have refreshed while this one waited for the lock.
        if self.token_has_expired():
            self.refresh_and_set_access_token()

    return self.access_token
Returns the access token.
This method uses double-checked locking to ensure thread-safe token refresh. When multiple threads (streams) detect an expired token simultaneously, only one will perform the refresh while others wait. After acquiring the lock, the token expiry is re-checked to avoid redundant refresh attempts.
def refresh_and_set_access_token(self) -> None:
    """Refresh the access token unconditionally and store the result.

    No expiry check is made here. The token returned by refresh_access_token()
    is assigned to access_token and the accompanying expiry date is passed to
    set_token_expiry_date(). Subclasses may override this to update additional
    state (e.g., persisting new refresh tokens).
    """
    new_token, new_expiry_date = self.refresh_access_token()
    self.access_token = new_token
    self.set_token_expiry_date(new_expiry_date)
Force refresh the access token and update internal state.
This method refreshes the access token regardless of whether it has expired, and updates the internal token and expiry date. Subclasses may override this to handle additional state updates (e.g., persisting new refresh tokens).
def token_has_expired(self) -> bool:
    """Returns True when the current time is later than the token expiry date"""
    current_time = ab_datetime_now()
    expiry_date = self.get_token_expiry_date()
    return current_time > expiry_date
Returns True if the token is expired
def build_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Returns the request body to set on the refresh request.

    Override to define additional parameters.

    The body always carries the grant type and the refresh token. Client
    credentials (client_id and client_secret) are excluded from the body when
    refresh_request_headers contains an Authorization header (e.g., Basic auth).
    This is required by OAuth providers like Gong that expect credentials ONLY
    in the Authorization header and reject requests that include them in both
    places. The header name is matched case-insensitively because HTTP field
    names are case-insensitive.

    Scopes are added under the "scopes" key when any are returned by
    get_scopes(). Fields from get_refresh_request_body() are merged last and
    never override a key that is already present.

    :return: mapping of body field names to their values
    """
    headers = self.get_refresh_request_headers()
    # "authorization" and "Authorization" name the same HTTP header, so compare
    # case-insensitively instead of relying on the exact spelling of the key.
    credentials_in_header = bool(headers) and any(
        str(header_name).lower() == "authorization" for header_name in headers
    )

    payload: MutableMapping[str, Any] = {
        self.get_grant_type_name(): self.get_grant_type(),
    }

    # Only include client credentials in the body if they are not already in the header
    if not credentials_in_header:
        payload[self.get_client_id_name()] = self.get_client_id()
        payload[self.get_client_secret_name()] = self.get_client_secret()

    payload[self.get_refresh_token_name()] = self.get_refresh_token()

    scopes = self.get_scopes()
    if scopes:
        payload["scopes"] = scopes

    custom_fields = self.get_refresh_request_body()
    if custom_fields:
        for key, val in custom_fields.items():
            # We defer to existing oauth constructs over custom configured fields
            payload.setdefault(key, val)

    return payload
Returns the request body to set on the refresh request.
Override to define additional parameters.
Client credentials (client_id and client_secret) are excluded from the body when refresh_request_headers contains an Authorization header (e.g., Basic auth). This is required by OAuth providers like Gong that expect credentials ONLY in the Authorization header and reject requests that include them in both places.
169 def build_refresh_request_headers(self) -> Mapping[str, Any] | None: 170 """ 171 Returns the request headers to set on the refresh request 172 173 """ 174 headers = self.get_refresh_request_headers() 175 return headers if headers else None
Returns the request headers to set on the refresh request
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
    """
    Request a new access token and return it with its expiration datetime.

    The refresh response is first checked for the presence of the access
    token; the token and the expiry date are then extracted from that same
    response.

    :return: a tuple of (access_token, token_expiry_date)
    """
    refresh_response = self._make_handled_request()
    self._ensure_access_token_in_response(refresh_response)

    new_access_token = self._extract_access_token(refresh_response)
    new_expiry_date = self._extract_token_expiry_date(refresh_response)
    return new_access_token, new_expiry_date
Returns the access token and its expiration datetime
Returns
a tuple of (access_token, token_expiry_date)
@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
    """
    Returns the endpoint to refresh the access token.

    The refresh request is sent as a POST to this URL. A None endpoint,
    together with an access token that is already set, makes the
    authenticator use the stored access token without refreshing.
    """
Returns the endpoint to refresh the access token
@abstractmethod
def get_client_id_name(self) -> str:
    """
    The client id name to authenticate.

    Used as the key under which the client id is placed in the refresh
    request body.
    """
The client id name to authenticate
@abstractmethod
def get_client_secret_name(self) -> str:
    """
    The client secret name to authenticate.

    Used as the key under which the client secret is placed in the refresh
    request body.
    """
The client secret name to authenticate
@abstractmethod
def get_client_secret(self) -> str:
    """
    The client secret to authenticate.

    Added to the refresh request body under get_client_secret_name(), unless
    the refresh request headers already contain an Authorization header.
    """
The client secret to authenticate
@abstractmethod
def get_refresh_token_name(self) -> str:
    """
    The refresh token name to authenticate.

    Used as the key for the refresh token in the refresh request body, and as
    the key looked up when extracting a refresh token from the response.
    """
The refresh token name to authenticate
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
    """
    The token used to refresh the access token when it expires.

    Placed in the refresh request body under get_refresh_token_name().
    """
The token used to refresh the access token when it expires
@abstractmethod
def get_token_expiry_date(self) -> AirbyteDateTime:
    """
    Expiration date of the access token.

    The token is treated as expired once the current time is later than this
    value.
    """
Expiration date of the access token
@abstractmethod
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
    """
    Setter for access token expiration date.

    Called with the expiry date obtained from a token refresh.
    """
Setter for access token expiration date
@abstractmethod
def get_access_token_name(self) -> str:
    """
    Field to extract access token from in the response.

    The refresh response is searched for this key, including inside nested
    objects and arrays.
    """
Field to extract access token from in the response
@abstractmethod
def get_expires_in_name(self) -> str:
    """
    Returns the expires_in field name.

    The refresh response is searched for this key to determine the token
    expiry date.
    """
Returns the expires_in field name
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Returns the request body to set on the refresh request.

    These fields are merged into the body built by
    build_refresh_request_body(); a key that the built body already contains
    is not overridden.
    """
Returns the request body to set on the refresh request
@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
    """
    Returns the request headers to set on the refresh request.

    When these headers contain an "Authorization" entry, the client id and
    client secret are left out of the refresh request body. An empty result
    is sent as no headers at all.
    """
Returns the request headers to set on the refresh request
@abstractmethod
def get_grant_type(self) -> str:
    """
    Returns grant_type specified for requesting access_token.

    Placed in the refresh request body under get_grant_type_name().
    """
Returns grant_type specified for requesting access_token