airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from abc import abstractmethod
from datetime import timedelta
from json import JSONDecodeError
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

import backoff
import requests
from requests.auth import AuthBase

from airbyte_cdk.models import FailureType, Level
from airbyte_cdk.sources.http_logger import format_http_message
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse

from ..exceptions import DefaultBackoffException

logger = logging.getLogger("airbyte")
# Shared default returned by `AbstractOauth2Authenticator._message_repository`.
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()


class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is reached while searching for the
    target key in the token refresh response (the JSON returned by `_make_handled_request`).
    """

    # NOTE: "Recurtion" is misspelled, but the name is kept as-is because callers may catch it by name.


class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract base class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator
    is designed to generically perform the refresh flow without regard to how config fields are read or written,
    delegating that behavior to the classes implementing the interface.
    """

    # Stream name passed to the HTTP message formatter; token refreshes are not tied to a stream.
    _NO_STREAM_NAME = None

    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then http errors with such params will be wrapped in AirbyteTracedException.

        :param refresh_token_error_status_codes: HTTP status codes of a failed refresh that may mean the refresh token is invalid.
        :param refresh_token_error_key: key looked up in the JSON body of a failed refresh response.
        :param refresh_token_error_values: values of that key that mean the refresh token is invalid or expired.
        """
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the HTTP headers required to authenticate on the HTTP request"""
        request.headers.update(self.get_auth_header())
        return request

    @property
    def _is_access_token_flow(self) -> bool:
        """True when no refresh endpoint is configured and an access token is available, so it is used as-is."""
        return self.get_token_refresh_endpoint() is None and self.access_token is not None

    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
        """
        return False

    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the expiry datetime; set it if `expires_in` is returned as the expiration datetime instead of seconds until it expires.
        """
        return None

    def get_auth_header(self) -> Mapping[str, Any]:
        """HTTP header to set on the requests"""
        token = self.access_token if self._is_access_token_flow else self.get_access_token()
        return {"Authorization": f"Bearer {token}"}

    def get_access_token(self) -> str:
        """Returns the access token, refreshing it (and its expiry date) first if it has expired."""
        if self.token_has_expired():
            token, expiry_date = self.refresh_access_token()
            self.access_token = token
            self.set_token_expiry_date(expiry_date)

        return self.access_token

    def token_has_expired(self) -> bool:
        """Returns True if the current time is past the token expiry date."""
        return ab_datetime_now() > self.get_token_expiry_date()

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request

        Override to define additional parameters
        """
        payload: MutableMapping[str, Any] = {
            self.get_grant_type_name(): self.get_grant_type(),
            self.get_client_id_name(): self.get_client_id(),
            self.get_client_secret_name(): self.get_client_secret(),
            self.get_refresh_token_name(): self.get_refresh_token(),
        }

        scopes = self.get_scopes()
        if scopes:
            payload["scopes"] = scopes

        custom_body = self.get_refresh_request_body()
        if custom_body:
            for key, val in custom_body.items():
                # We defer to existing oauth constructs over custom configured fields
                payload.setdefault(key, val)

        return payload

    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
        """
        Returns the request headers to set on the refresh request, or None when no headers are configured.
        """
        headers = self.get_refresh_request_headers()
        return headers if headers else None

    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
        """
        Requests a new access token and returns it with its expiration datetime.

        :return: a tuple of (access_token, token_expiry_date)
        :raises Exception: if the refresh response does not contain an access token.
        """
        response_json = self._make_handled_request()
        self._ensure_access_token_in_response(response_json)

        return (
            self._extract_access_token(response_json),
            self._extract_token_expiry_date(response_json),
        )

    # ----------------
    # PRIVATE METHODS
    # ----------------

    def _default_token_expiry_date(self) -> AirbyteDateTime:
        """
        Returns the default token expiry date: one hour from now.
        """
        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
        default_token_expiry_duration_hours = 1  # 1 hour
        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)

    def _wrap_refresh_token_exception(
        self, exception: requests.exceptions.RequestException
    ) -> bool:
        """
        Checks whether a failed refresh request indicates an invalid or expired refresh token.

        The check matches the response status code against the configured error status codes, and the
        value of the configured error key in the JSON error body against the configured error values.

        Args:
            exception (requests.exceptions.RequestException): The exception raised during the request.

        Returns:
            bool: True if the exception is related to a refresh token error, False otherwise
            (including when there is no response, or its body is not a JSON object).
        """
        response = exception.response
        if response is None:
            return False
        try:
            exception_content = response.json()
        except ValueError:
            # json.JSONDecodeError, simplejson's JSONDecodeError (used by requests when simplejson is
            # installed) and requests.exceptions.JSONDecodeError all derive from ValueError.
            return False
        if not isinstance(exception_content, dict):
            # A JSON array/scalar body cannot carry the configured error key.
            return False
        return (
            response.status_code in self._refresh_token_error_status_codes
            and exception_content.get(self._refresh_token_error_key)
            in self._refresh_token_error_values
        )

    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the necessary
        headers and body to obtain a new access token. It handles various exceptions that
        may occur during the request and logs the response for troubleshooting purposes.
        Retryable failures (DefaultBackoffException) are retried with exponential backoff
        for up to 300 seconds.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication.
            requests.exceptions.RequestException: Other request failures are re-raised unchanged.
            Exception: For any other exceptions that occur during the request.
        """
        try:
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )
            # log the response even if the request failed for troubleshooting purposes
            self._log_response(response)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(request=e.response.request, response=e.response)
            if self._wrap_refresh_token_exception(e):
                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                raise AirbyteTracedException(
                    internal_message=message, message=message, failure_type=FailureType.config_error
                )
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e

    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
        """
        Ensures that the access token is present in the response data.

        If the access token is found, it is added to the list of secrets so that it is
        masked before the response is logged.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Raises:
            Exception: If the access token is not found in the response data.
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
        """
        access_key = self._extract_access_token(response_data)
        if not access_key:
            raise Exception(
                f"Token refresh API response was missing access token {self.get_access_token_name()}"
            )
        # Add the access token to the list of secrets so it is replaced before logging the response
        # An argument could be made to remove the previous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
        add_to_secrets(access_key)

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """
        Parse a string or integer token expiration date into a datetime object.

        When `token_expiry_is_time_of_expiration` is True, `value` is parsed as a datetime;
        otherwise it is interpreted as a number of seconds from now.

        :return: expiration datetime
        :raises ValueError: if the value cannot be interpreted.
        """
        if self.token_expiry_is_time_of_expiration:
            if not self.token_expiry_date_format:
                raise ValueError(
                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
                )
            # NOTE(review): the configured format is only checked for presence; it is not passed
            # to ab_datetime_parse, which presumably detects the format itself - confirm.
            try:
                return ab_datetime_parse(str(value))
            except ValueError as e:
                raise ValueError(f"Invalid token expiry date format: {e}") from e

        try:
            # Only accept numeric values (as int/float/string) when no format specified
            seconds = int(float(str(value)))
            return ab_datetime_now() + timedelta(seconds=seconds)
        except (ValueError, TypeError, OverflowError) as e:
            # OverflowError covers "inf" and values too large for timedelta/datetime arithmetic.
            raise ValueError(
                f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
            ) from e

    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the access token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Returns:
            The extracted access token, or None if the key is not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())

    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the refresh token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.

        Returns:
            The extracted refresh token, or None if the key is not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
        """
        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

        If the token_expiry_date is not found, it returns the existing token expiry date if it is set and
        not yet passed, or otherwise a default token expiry date.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

        Returns:
            The token expiry date; never None.
        """
        expires_in = self._find_and_get_value_from_response(
            response_data, self.get_expires_in_name()
        )
        if expires_in is not None:
            return self._parse_token_expiration_date(expires_in)

        # expires_in is None
        existing_expiry_date = self.get_token_expiry_date()
        if existing_expiry_date and not self.token_has_expired():
            return existing_expiry_date

        return self._default_token_expiry_date()

    def _find_and_get_value_from_response(
        self,
        response_data: Mapping[str, Any],
        key_name: str,
        max_depth: int = 5,
        current_depth: int = 0,
    ) -> Any:
        """
        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.

        Keys at the current level are checked before descending into nested values.
        A value of None counts as "not found", so the search continues past it.

        Args:
            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
            key_name (str): The key to search for in the response data.
            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
            current_depth (int, optional): The current depth of the recursion. Defaults to 0.

        Returns:
            Any: The value associated with the specified key if found, otherwise None.

        Raises:
            ResponseKeysMaxRecurtionReached: If the search descends deeper than max_depth.
        """
        if current_depth > max_depth:
            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
            raise ResponseKeysMaxRecurtionReached(
                internal_message=message, message=message, failure_type=FailureType.config_error
            )

        if isinstance(response_data, dict):
            # get from the root level
            if key_name in response_data:
                return response_data[key_name]
            children = response_data.values()
        elif isinstance(response_data, list):
            children = response_data
        else:
            return None

        # get from the nested objects / array items
        for child in children:
            result = self._find_and_get_value_from_response(
                child, key_name, max_depth, current_depth + 1
            )
            if result is not None:
                return result

        return None

    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        The implementation can define a message_repository if it wants debugging logs for HTTP requests
        """
        return _NOOP_MESSAGE_REPOSITORY

    def _log_response(self, response: requests.Response) -> None:
        """
        Logs the HTTP response at DEBUG level using the message repository if it is available.

        Args:
            response (requests.Response): The HTTP response to log.
        """
        if self._message_repository:
            self._message_repository.log_message(
                Level.DEBUG,
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    self._NO_STREAM_NAME,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    # ----------------
    # ABSTR METHODS
    # ----------------

    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token, or None to use the access token as-is"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """Name of the client id field in the refresh request body"""

    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """Name of the client secret field in the refresh request body"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """Name of the refresh token field in the refresh request body and response"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes; sent as "scopes" in the refresh request body when non-empty"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Name of the field holding the token expiry (seconds or datetime) in the response"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns custom fields to add to the refresh request body"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is reached while searching for the
    target key in the token refresh response (the JSON returned by `_make_handled_request`).
    """

    # NOTE: "Recurtion" is misspelled, but the name is kept as-is because callers may catch it by name.
Raised when the maximum recursion depth is reached while searching for the
target key in the token refresh response (the JSON returned by _make_handled_request).
class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract base class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator
    is designed to generically perform the refresh flow without regard to how config fields are read or written,
    delegating that behavior to the classes implementing the interface.
    """

    # Stream name passed to the HTTP message formatter; token refreshes are not tied to a stream.
    _NO_STREAM_NAME = None

    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then http errors with such params will be wrapped in AirbyteTracedException.

        :param refresh_token_error_status_codes: HTTP status codes of a failed refresh that may mean the refresh token is invalid.
        :param refresh_token_error_key: key looked up in the JSON body of a failed refresh response.
        :param refresh_token_error_values: values of that key that mean the refresh token is invalid or expired.
        """
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the HTTP headers required to authenticate on the HTTP request"""
        request.headers.update(self.get_auth_header())
        return request

    @property
    def _is_access_token_flow(self) -> bool:
        """True when no refresh endpoint is configured and an access token is available, so it is used as-is."""
        return self.get_token_refresh_endpoint() is None and self.access_token is not None

    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
        """
        return False

    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the expiry datetime; set it if `expires_in` is returned as the expiration datetime instead of seconds until it expires.
        """
        return None

    def get_auth_header(self) -> Mapping[str, Any]:
        """HTTP header to set on the requests"""
        token = self.access_token if self._is_access_token_flow else self.get_access_token()
        return {"Authorization": f"Bearer {token}"}

    def get_access_token(self) -> str:
        """Returns the access token, refreshing it (and its expiry date) first if it has expired."""
        if self.token_has_expired():
            token, expiry_date = self.refresh_access_token()
            self.access_token = token
            self.set_token_expiry_date(expiry_date)

        return self.access_token

    def token_has_expired(self) -> bool:
        """Returns True if the current time is past the token expiry date."""
        return ab_datetime_now() > self.get_token_expiry_date()

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request

        Override to define additional parameters
        """
        payload: MutableMapping[str, Any] = {
            self.get_grant_type_name(): self.get_grant_type(),
            self.get_client_id_name(): self.get_client_id(),
            self.get_client_secret_name(): self.get_client_secret(),
            self.get_refresh_token_name(): self.get_refresh_token(),
        }

        scopes = self.get_scopes()
        if scopes:
            payload["scopes"] = scopes

        custom_body = self.get_refresh_request_body()
        if custom_body:
            for key, val in custom_body.items():
                # We defer to existing oauth constructs over custom configured fields
                payload.setdefault(key, val)

        return payload

    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
        """
        Returns the request headers to set on the refresh request, or None when no headers are configured.
        """
        headers = self.get_refresh_request_headers()
        return headers if headers else None

    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
        """
        Requests a new access token and returns it with its expiration datetime.

        :return: a tuple of (access_token, token_expiry_date)
        :raises Exception: if the refresh response does not contain an access token.
        """
        response_json = self._make_handled_request()
        self._ensure_access_token_in_response(response_json)

        return (
            self._extract_access_token(response_json),
            self._extract_token_expiry_date(response_json),
        )

    # ----------------
    # PRIVATE METHODS
    # ----------------

    def _default_token_expiry_date(self) -> AirbyteDateTime:
        """
        Returns the default token expiry date: one hour from now.
        """
        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
        default_token_expiry_duration_hours = 1  # 1 hour
        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)

    def _wrap_refresh_token_exception(
        self, exception: requests.exceptions.RequestException
    ) -> bool:
        """
        Checks whether a failed refresh request indicates an invalid or expired refresh token.

        The check matches the response status code against the configured error status codes, and the
        value of the configured error key in the JSON error body against the configured error values.

        Args:
            exception (requests.exceptions.RequestException): The exception raised during the request.

        Returns:
            bool: True if the exception is related to a refresh token error, False otherwise
            (including when there is no response, or its body is not a JSON object).
        """
        response = exception.response
        if response is None:
            return False
        try:
            exception_content = response.json()
        except ValueError:
            # json.JSONDecodeError, simplejson's JSONDecodeError (used by requests when simplejson is
            # installed) and requests.exceptions.JSONDecodeError all derive from ValueError.
            return False
        if not isinstance(exception_content, dict):
            # A JSON array/scalar body cannot carry the configured error key.
            return False
        return (
            response.status_code in self._refresh_token_error_status_codes
            and exception_content.get(self._refresh_token_error_key)
            in self._refresh_token_error_values
        )

    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the necessary
        headers and body to obtain a new access token. It handles various exceptions that
        may occur during the request and logs the response for troubleshooting purposes.
        Retryable failures (DefaultBackoffException) are retried with exponential backoff
        for up to 300 seconds.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication.
            requests.exceptions.RequestException: Other request failures are re-raised unchanged.
            Exception: For any other exceptions that occur during the request.
        """
        try:
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )
            # log the response even if the request failed for troubleshooting purposes
            self._log_response(response)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(request=e.response.request, response=e.response)
            if self._wrap_refresh_token_exception(e):
                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                raise AirbyteTracedException(
                    internal_message=message, message=message, failure_type=FailureType.config_error
                )
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e

    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
        """
        Ensures that the access token is present in the response data.

        If the access token is found, it is added to the list of secrets so that it is
        masked before the response is logged.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Raises:
            Exception: If the access token is not found in the response data.
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
        """
        access_key = self._extract_access_token(response_data)
        if not access_key:
            raise Exception(
                f"Token refresh API response was missing access token {self.get_access_token_name()}"
            )
        # Add the access token to the list of secrets so it is replaced before logging the response
        # An argument could be made to remove the previous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
        add_to_secrets(access_key)

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """
        Parse a string or integer token expiration date into a datetime object.

        When `token_expiry_is_time_of_expiration` is True, `value` is parsed as a datetime;
        otherwise it is interpreted as a number of seconds from now.

        :return: expiration datetime
        :raises ValueError: if the value cannot be interpreted.
        """
        if self.token_expiry_is_time_of_expiration:
            if not self.token_expiry_date_format:
                raise ValueError(
                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
                )
            # NOTE(review): the configured format is only checked for presence; it is not passed
            # to ab_datetime_parse, which presumably detects the format itself - confirm.
            try:
                return ab_datetime_parse(str(value))
            except ValueError as e:
                raise ValueError(f"Invalid token expiry date format: {e}") from e

        try:
            # Only accept numeric values (as int/float/string) when no format specified
            seconds = int(float(str(value)))
            return ab_datetime_now() + timedelta(seconds=seconds)
        except (ValueError, TypeError, OverflowError) as e:
            # OverflowError covers "inf" and values too large for timedelta/datetime arithmetic.
            raise ValueError(
                f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
            ) from e

    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the access token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Returns:
            The extracted access token, or None if the key is not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())

    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the refresh token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.

        Returns:
            The extracted refresh token, or None if the key is not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
        """
        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

        If the token_expiry_date is not found, it returns the existing token expiry date if it is set and
        not yet passed, or otherwise a default token expiry date.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

        Returns:
            The token expiry date; never None.
        """
        expires_in = self._find_and_get_value_from_response(
            response_data, self.get_expires_in_name()
        )
        if expires_in is not None:
            return self._parse_token_expiration_date(expires_in)

        # expires_in is None
        existing_expiry_date = self.get_token_expiry_date()
        if existing_expiry_date and not self.token_has_expired():
            return existing_expiry_date

        return self._default_token_expiry_date()

    def _find_and_get_value_from_response(
        self,
        response_data: Mapping[str, Any],
        key_name: str,
        max_depth: int = 5,
        current_depth: int = 0,
    ) -> Any:
        """
        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.

        Keys at the current level are checked before descending into nested values.
        A value of None counts as "not found", so the search continues past it.

        Args:
            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
            key_name (str): The key to search for in the response data.
            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
            current_depth (int, optional): The current depth of the recursion. Defaults to 0.

        Returns:
            Any: The value associated with the specified key if found, otherwise None.

        Raises:
            ResponseKeysMaxRecurtionReached: If the search descends deeper than max_depth.
        """
        if current_depth > max_depth:
            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
            raise ResponseKeysMaxRecurtionReached(
                internal_message=message, message=message, failure_type=FailureType.config_error
            )

        if isinstance(response_data, dict):
            # get from the root level
            if key_name in response_data:
                return response_data[key_name]
            children = response_data.values()
        elif isinstance(response_data, list):
            children = response_data
        else:
            return None

        # get from the nested objects / array items
        for child in children:
            result = self._find_and_get_value_from_response(
                child, key_name, max_depth, current_depth + 1
            )
            if result is not None:
                return result

        return None

    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        The implementation can define a message_repository if it wants debugging logs for HTTP requests
        """
        return _NOOP_MESSAGE_REPOSITORY

    def _log_response(self, response: requests.Response) -> None:
        """
        Logs the HTTP response at DEBUG level using the message repository if it is available.

        Args:
            response (requests.Response): The HTTP response to log.
        """
        if self._message_repository:
            self._message_repository.log_message(
                Level.DEBUG,
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    self._NO_STREAM_NAME,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    # ----------------
    # ABSTR METHODS
    # ----------------

    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token, or None to use the access token as-is"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """Name of the client id field in the refresh request body"""

    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """Name of the client secret field in the refresh request body"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """Name of the refresh token field in the refresh request body and response"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes; sent as "scopes" in the refresh request body when non-empty"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Name of the field holding the token expiry (seconds or datetime) in the response"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns custom fields to add to the refresh request body"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator is designed to perform the refresh flow generically, without regard to how config fields are read or written, by delegating that behavior to the classes implementing the interface.
def __init__(
    self,
    refresh_token_error_status_codes: Tuple[int, ...] = (),
    refresh_token_error_key: str = "",
    refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
    """
    Store the settings used to recognise refresh-token errors.

    Only when all three of refresh_token_error_status_codes, refresh_token_error_key and
    refresh_token_error_values are provided are matching HTTP errors wrapped in an
    AirbyteTracedException.
    """
    (
        self._refresh_token_error_status_codes,
        self._refresh_token_error_key,
        self._refresh_token_error_values,
    ) = (
        refresh_token_error_status_codes,
        refresh_token_error_key,
        refresh_token_error_values,
    )
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, then HTTP errors matching these parameters will be wrapped in AirbyteTracedException.
@property
def token_expiry_is_time_of_expiration(self) -> bool:
    """
    Whether the token expiry is the date until which the token stays valid, rather than the
    amount of time it stays valid.

    The base implementation reports an amount of time (False).
    """
    is_absolute_date = False
    return is_absolute_date
Indicates whether the token expiry is the date until which the token will be valid, rather than the amount of time it will be valid.
@property
def token_expiry_date_format(self) -> Optional[str]:
    """
    Format of the token expiry datetime.

    Override this to return the format string when expires_in is returned as the expiration
    datetime instead of the number of seconds until the token expires. Defaults to None.
    """
    return None
Format of the expiry datetime; set it if expires_in is returned as the expiration datetime instead of the number of seconds until the token expires.
def get_auth_header(self) -> Mapping[str, Any]:
    """
    Build the Authorization header to set on outgoing requests.

    In the access-token-only flow the stored token is used as-is; otherwise the token comes
    from get_access_token(), which refreshes it first when it has expired.
    """
    if self._is_access_token_flow:
        bearer = self.access_token
    else:
        bearer = self.get_access_token()
    return {"Authorization": f"Bearer {bearer}"}
HTTP header to set on the requests
def get_access_token(self) -> str:
    """
    Return the current access token, refreshing it first if it has expired.

    On refresh, both the stored access token and its expiry date are updated.
    """
    if not self.token_has_expired():
        return self.access_token

    fresh_token, fresh_expiry = self.refresh_access_token()
    self.access_token = fresh_token
    self.set_token_expiry_date(fresh_expiry)
    return self.access_token
Returns the access token
def token_has_expired(self) -> bool:
    """Return True once the current time is later than the stored token expiry date."""
    now = ab_datetime_now()
    expires_at = self.get_token_expiry_date()
    return now > expires_at
Returns True if the token is expired
def build_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Returns the request body to set on the refresh request.

    The body always contains the grant type, client id, client secret and refresh token,
    keyed by the names returned from the matching ``get_*_name`` methods. It also contains
    ``scopes`` when any scopes are configured. Entries from get_refresh_request_body() are
    added only for keys that are not already present, so the standard OAuth fields win.

    Override to define additional parameters
    """
    payload: MutableMapping[str, Any] = {
        self.get_grant_type_name(): self.get_grant_type(),
        self.get_client_id_name(): self.get_client_id(),
        self.get_client_secret_name(): self.get_client_secret(),
        self.get_refresh_token_name(): self.get_refresh_token(),
    }

    # Call each overridable getter once; implementations may do non-trivial work in them.
    scopes = self.get_scopes()
    if scopes:
        payload["scopes"] = scopes

    custom_body = self.get_refresh_request_body()
    if custom_body:
        for key, val in custom_body.items():
            # We defer to existing oauth constructs over custom configured fields
            payload.setdefault(key, val)

    return payload
Returns the request body to set on the refresh request
Override to define additional parameters
126 def build_refresh_request_headers(self) -> Mapping[str, Any] | None: 127 """ 128 Returns the request headers to set on the refresh request 129 130 """ 131 headers = self.get_refresh_request_headers() 132 return headers if headers else None
Returns the request headers to set on the refresh request
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
    """
    Returns the access token and its expiration datetime

    Performs the refresh request, validates the response with
    _ensure_access_token_in_response(), and extracts both values from the response.

    :return: a tuple of (access_token, token_expiry_date)
    """
    response_json = self._make_handled_request()
    self._ensure_access_token_in_response(response_json)

    return (
        self._extract_access_token(response_json),
        self._extract_token_expiry_date(response_json),
    )
Returns the access token and its expiration datetime
Returns
a tuple of (access_token, token_expiry_date)
@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
    """Return the URL used to refresh the access token, or None when there is no refresh endpoint."""
Returns the endpoint to refresh the access token
@abstractmethod
def get_client_id_name(self) -> str:
    """Return the name of the refresh-request body field that carries the client id."""
The client id name to authenticate
@abstractmethod
def get_client_secret_name(self) -> str:
    """Return the name of the refresh-request body field that carries the client secret."""
The client secret name to authenticate
@abstractmethod
def get_client_secret(self) -> str:
    """Return the client secret sent in the refresh request."""
The client secret to authenticate
@abstractmethod
def get_refresh_token_name(self) -> str:
    """Return the field name under which the refresh token is sent in the refresh request."""
The refresh token name to authenticate
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
    """Return the token used to refresh the access token when it expires, or None if unavailable."""
The token used to refresh the access token when it expires
@abstractmethod
def get_token_expiry_date(self) -> AirbyteDateTime:
    """Return the datetime at which the current access token expires."""
Expiration date of the access token
@abstractmethod
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
    """Store the expiration datetime of the access token."""
Setter for access token expiration date
@abstractmethod
def get_access_token_name(self) -> str:
    """Return the name of the response field from which the access token is extracted."""
Field to extract access token from in the response
@abstractmethod
def get_expires_in_name(self) -> str:
    """Return the name of the response field holding the token expiry (e.g. ``expires_in``)."""
Returns the expires_in field name
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Return additional fields for the refresh request body.

    build_refresh_request_body() adds these only for keys it does not already set.
    """
Returns the request body to set on the refresh request
@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
    """Return the headers to send with the refresh request (an empty mapping means none)."""
Returns the request headers to set on the refresh request
@abstractmethod
def get_grant_type(self) -> str:
    """Return the grant type value sent when requesting an access token."""
Returns grant_type specified for requesting access_token