airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from abc import abstractmethod
from datetime import timedelta
from json import JSONDecodeError
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

import backoff
import requests
from requests.auth import AuthBase

from airbyte_cdk.models import FailureType, Level
from airbyte_cdk.sources.http_logger import format_http_message
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse

from ..exceptions import DefaultBackoffException

logger = logging.getLogger("airbyte")
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()


class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is reached while searching the
    token-refresh response for a target key (see `_find_and_get_value_from_response`).
    """


class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator
    is designed to generically perform the refresh flow without regard to how config fields are read or
    written, by delegating that behavior to the classes implementing the interface.
    """

    _NO_STREAM_NAME = None

    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then HTTP errors matching them will be wrapped in AirbyteTracedException.

        :param refresh_token_error_status_codes: status codes of the token endpoint that may signal a bad refresh token
        :param refresh_token_error_key: key looked up in the JSON error body of such a response
        :param refresh_token_error_values: values of that key which signal a bad refresh token
        """
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the HTTP headers required to authenticate on the HTTP request"""
        request.headers.update(self.get_auth_header())
        return request

    @property
    def _is_access_token_flow(self) -> bool:
        # No refresh endpoint but an access token is already set: use it as-is, never refresh.
        return self.get_token_refresh_endpoint() is None and self.access_token is not None

    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates whether the token expiry value is the date until which the token will be valid,
        rather than the amount of time it will be valid.
        """

        return False

    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the expiry datetime; set it if expires_in is returned as the expiration datetime
        instead of the number of seconds until the token expires.
        """

        return None

    def get_auth_header(self) -> Mapping[str, Any]:
        """HTTP header to set on the requests"""
        token = self.access_token if self._is_access_token_flow else self.get_access_token()
        return {"Authorization": f"Bearer {token}"}

    def get_access_token(self) -> str:
        """Returns the access token, refreshing it first if it has expired"""
        if self.token_has_expired():
            token, expires_in = self.refresh_access_token()
            self.access_token = token
            self.set_token_expiry_date(expires_in)

        return self.access_token

    def token_has_expired(self) -> bool:
        """Returns True if the token is expired"""
        return ab_datetime_now() > self.get_token_expiry_date()

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request

        Override to define additional parameters
        """
        payload: MutableMapping[str, Any] = {
            self.get_grant_type_name(): self.get_grant_type(),
            self.get_client_id_name(): self.get_client_id(),
            self.get_client_secret_name(): self.get_client_secret(),
            self.get_refresh_token_name(): self.get_refresh_token(),
        }

        scopes = self.get_scopes()
        if scopes:
            payload["scopes"] = scopes

        custom_body = self.get_refresh_request_body()
        if custom_body:
            for key, val in custom_body.items():
                # We defer to existing oauth constructs over custom configured fields
                if key not in payload:
                    payload[key] = val

        return payload

    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
        """
        Returns the request headers to set on the refresh request, or None when there are none
        """
        headers = self.get_refresh_request_headers()
        return headers if headers else None

    def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
        """
        Requests a new access token from the token refresh endpoint.

        :return: a tuple of (access_token, token_lifespan), where token_lifespan is the raw expiry
            value found in the response (seconds or a datetime string, depending on
            token_expiry_is_time_of_expiration)
        """
        response_json = self._make_handled_request()
        self._ensure_access_token_in_response(response_json)

        return (
            self._extract_access_token(response_json),
            self._extract_token_expiry_date(response_json),
        )

    # ----------------
    # PRIVATE METHODS
    # ----------------

    def _wrap_refresh_token_exception(
        self, exception: requests.exceptions.RequestException
    ) -> bool:
        """
        Checks whether a failed refresh request means the refresh token is invalid.

        The exception's response must have one of the configured status codes and a JSON object
        body whose configured error key holds one of the configured error values.

        Args:
            exception (requests.exceptions.RequestException): The exception raised during the request.

        Returns:
            bool: True if the exception is related to a refresh token error, False otherwise.
        """
        response = exception.response
        if response is None:
            return False
        try:
            exception_content = response.json()
        except ValueError:
            # json.JSONDecodeError, simplejson.JSONDecodeError and requests' JSONDecodeError all
            # derive from ValueError, so this covers whichever JSON backend requests is using.
            return False
        if not isinstance(exception_content, Mapping):
            # A JSON array/string/number body carries no error key to inspect.
            return False
        return (
            response.status_code in self._refresh_token_error_status_codes
            and exception_content.get(self._refresh_token_error_key)
            in self._refresh_token_error_values
        )

    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the necessary
        headers and body to obtain a new access token. It handles various exceptions that
        may occur during the request and logs the response for troubleshooting purposes.
        Retryable failures are retried with exponential backoff for up to 300 seconds.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication.
            Exception: For any other non-requests exception that occurs during the request.
        """
        try:
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )
            # log the response even if the request failed for troubleshooting purposes
            self._log_response(response)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(request=e.response.request, response=e.response)
            if self._wrap_refresh_token_exception(e):
                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                raise AirbyteTracedException(
                    internal_message=message, message=message, failure_type=FailureType.config_error
                )
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e

    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
        """
        Ensures that the access token is present in the response data.

        If the access token is not found, an exception is raised. If it is found, it is added
        to the list of secrets so it is masked before the response is logged.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Raises:
            Exception: If the access token is not found in the response data.
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
        """
        access_key = self._extract_access_token(response_data)
        if not access_key:
            raise Exception(
                f"Token refresh API response was missing access token {self.get_access_token_name()}"
            )
        # Add the access token to the list of secrets so it is replaced before logging the response
        # An argument could be made to remove the previous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
        add_to_secrets(access_key)

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """
        Return the expiration datetime of the access token.

        :param value: raw expiry value from the refresh response; a datetime string when
            token_expiry_is_time_of_expiration is True, otherwise a number of seconds
            (int, float or numeric string)
        :return: expiration datetime
        :raises ValueError: if the value cannot be interpreted
        """
        if not value and not self.token_has_expired():
            # No expiry token was provided but the previous one is not expired so it's fine
            return self.get_token_expiry_date()

        if self.token_expiry_is_time_of_expiration:
            if not self.token_expiry_date_format:
                raise ValueError(
                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
                )
            # NOTE(review): the configured format is required but not passed to ab_datetime_parse.
            try:
                return ab_datetime_parse(str(value))
            except ValueError as e:
                raise ValueError(f"Invalid token expiry date format: {e}") from e

        try:
            # Only accept numeric values (as int/float/string) when no format specified
            seconds = int(float(str(value)))
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
            ) from e
        return ab_datetime_now() + timedelta(seconds=seconds)

    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the access token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Returns:
            str: The extracted access token, or None if it is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())

    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the refresh token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.

        Returns:
            str: The extracted refresh token, or None if it is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

        Returns:
            str: The extracted token_expiry_date, or None if it is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())

    def _find_and_get_value_from_response(
        self,
        response_data: Mapping[str, Any],
        key_name: str,
        max_depth: int = 5,
        current_depth: int = 0,
    ) -> Any:
        """
        Recursively searches for a specified key in a nested mapping or list and returns its value if found.

        Keys at the current level are checked before descending; children are then searched
        depth-first in iteration order, and the first non-None match wins.

        Args:
            response_data (Mapping[str, Any]): The response data to search through, which can be a mapping or a list.
            key_name (str): The key to search for in the response data.
            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
            current_depth (int, optional): The current depth of the recursion. Defaults to 0.

        Returns:
            Any: The value associated with the specified key if found, otherwise None.

        Raises:
            ResponseKeysMaxRecurtionReached: If the search descends past max_depth.
        """
        if current_depth > max_depth:
            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
            message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response."
            raise ResponseKeysMaxRecurtionReached(
                internal_message=message, message=message, failure_type=FailureType.config_error
            )

        if isinstance(response_data, Mapping):
            # get from the root level
            if key_name in response_data:
                return response_data[key_name]
            # then from the nested objects
            children: List[Any] = list(response_data.values())
        elif isinstance(response_data, list):
            # get from the nested array objects
            children = response_data
        else:
            return None

        for child in children:
            result = self._find_and_get_value_from_response(
                child, key_name, max_depth, current_depth + 1
            )
            if result is not None:
                return result

        return None

    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        The implementation can define a message_repository if it wants debugging logs for HTTP requests
        """
        return _NOOP_MESSAGE_REPOSITORY

    def _log_response(self, response: requests.Response) -> None:
        """
        Logs the HTTP response using the message repository if it is available.

        Args:
            response (requests.Response): The HTTP response to log.
        """
        if self._message_repository:
            self._message_repository.log_message(
                Level.DEBUG,
                # formatting is deferred to the repository via a callable
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    self._NO_STREAM_NAME,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    # ----------------
    # ABSTR METHODS
    # ----------------

    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """The client id name to authenticate"""

    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """The client secret name to authenticate"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """The refresh token name to authenticate"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: Union[str, int]) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the expires_in field name"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns the request body to set on the refresh request"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is reached while searching the
    token-refresh response for a target key (see `_find_and_get_value_from_response`).
    """
Raised when the maximum recursion depth is reached while searching
the token-refresh response for the target key.
class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator
    is designed to generically perform the refresh flow without regard to how config fields are read or
    written, by delegating that behavior to the classes implementing the interface.
    """

    _NO_STREAM_NAME = None

    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then HTTP errors matching them will be wrapped in AirbyteTracedException.

        :param refresh_token_error_status_codes: status codes of the token endpoint that may signal a bad refresh token
        :param refresh_token_error_key: key looked up in the JSON error body of such a response
        :param refresh_token_error_values: values of that key which signal a bad refresh token
        """
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the HTTP headers required to authenticate on the HTTP request"""
        request.headers.update(self.get_auth_header())
        return request

    @property
    def _is_access_token_flow(self) -> bool:
        # No refresh endpoint but an access token is already set: use it as-is, never refresh.
        return self.get_token_refresh_endpoint() is None and self.access_token is not None

    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates whether the token expiry value is the date until which the token will be valid,
        rather than the amount of time it will be valid.
        """

        return False

    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the expiry datetime; set it if expires_in is returned as the expiration datetime
        instead of the number of seconds until the token expires.
        """

        return None

    def get_auth_header(self) -> Mapping[str, Any]:
        """HTTP header to set on the requests"""
        token = self.access_token if self._is_access_token_flow else self.get_access_token()
        return {"Authorization": f"Bearer {token}"}

    def get_access_token(self) -> str:
        """Returns the access token, refreshing it first if it has expired"""
        if self.token_has_expired():
            token, expires_in = self.refresh_access_token()
            self.access_token = token
            self.set_token_expiry_date(expires_in)

        return self.access_token

    def token_has_expired(self) -> bool:
        """Returns True if the token is expired"""
        return ab_datetime_now() > self.get_token_expiry_date()

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request

        Override to define additional parameters
        """
        payload: MutableMapping[str, Any] = {
            self.get_grant_type_name(): self.get_grant_type(),
            self.get_client_id_name(): self.get_client_id(),
            self.get_client_secret_name(): self.get_client_secret(),
            self.get_refresh_token_name(): self.get_refresh_token(),
        }

        scopes = self.get_scopes()
        if scopes:
            payload["scopes"] = scopes

        custom_body = self.get_refresh_request_body()
        if custom_body:
            for key, val in custom_body.items():
                # We defer to existing oauth constructs over custom configured fields
                if key not in payload:
                    payload[key] = val

        return payload

    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
        """
        Returns the request headers to set on the refresh request, or None when there are none
        """
        headers = self.get_refresh_request_headers()
        return headers if headers else None

    def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
        """
        Requests a new access token from the token refresh endpoint.

        :return: a tuple of (access_token, token_lifespan), where token_lifespan is the raw expiry
            value found in the response (seconds or a datetime string, depending on
            token_expiry_is_time_of_expiration)
        """
        response_json = self._make_handled_request()
        self._ensure_access_token_in_response(response_json)

        return (
            self._extract_access_token(response_json),
            self._extract_token_expiry_date(response_json),
        )

    # ----------------
    # PRIVATE METHODS
    # ----------------

    def _wrap_refresh_token_exception(
        self, exception: requests.exceptions.RequestException
    ) -> bool:
        """
        Checks whether a failed refresh request means the refresh token is invalid.

        The exception's response must have one of the configured status codes and a JSON object
        body whose configured error key holds one of the configured error values.

        Args:
            exception (requests.exceptions.RequestException): The exception raised during the request.

        Returns:
            bool: True if the exception is related to a refresh token error, False otherwise.
        """
        response = exception.response
        if response is None:
            return False
        try:
            exception_content = response.json()
        except ValueError:
            # json.JSONDecodeError, simplejson.JSONDecodeError and requests' JSONDecodeError all
            # derive from ValueError, so this covers whichever JSON backend requests is using.
            return False
        if not isinstance(exception_content, Mapping):
            # A JSON array/string/number body carries no error key to inspect.
            return False
        return (
            response.status_code in self._refresh_token_error_status_codes
            and exception_content.get(self._refresh_token_error_key)
            in self._refresh_token_error_values
        )

    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the necessary
        headers and body to obtain a new access token. It handles various exceptions that
        may occur during the request and logs the response for troubleshooting purposes.
        Retryable failures are retried with exponential backoff for up to 300 seconds.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication.
            Exception: For any other non-requests exception that occurs during the request.
        """
        try:
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )
            # log the response even if the request failed for troubleshooting purposes
            self._log_response(response)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(request=e.response.request, response=e.response)
            if self._wrap_refresh_token_exception(e):
                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                raise AirbyteTracedException(
                    internal_message=message, message=message, failure_type=FailureType.config_error
                )
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e

    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
        """
        Ensures that the access token is present in the response data.

        If the access token is not found, an exception is raised. If it is found, it is added
        to the list of secrets so it is masked before the response is logged.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Raises:
            Exception: If the access token is not found in the response data.
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
        """
        access_key = self._extract_access_token(response_data)
        if not access_key:
            raise Exception(
                f"Token refresh API response was missing access token {self.get_access_token_name()}"
            )
        # Add the access token to the list of secrets so it is replaced before logging the response
        # An argument could be made to remove the previous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
        add_to_secrets(access_key)

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """
        Return the expiration datetime of the access token.

        :param value: raw expiry value from the refresh response; a datetime string when
            token_expiry_is_time_of_expiration is True, otherwise a number of seconds
            (int, float or numeric string)
        :return: expiration datetime
        :raises ValueError: if the value cannot be interpreted
        """
        if not value and not self.token_has_expired():
            # No expiry token was provided but the previous one is not expired so it's fine
            return self.get_token_expiry_date()

        if self.token_expiry_is_time_of_expiration:
            if not self.token_expiry_date_format:
                raise ValueError(
                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
                )
            # NOTE(review): the configured format is required but not passed to ab_datetime_parse.
            try:
                return ab_datetime_parse(str(value))
            except ValueError as e:
                raise ValueError(f"Invalid token expiry date format: {e}") from e

        try:
            # Only accept numeric values (as int/float/string) when no format specified
            seconds = int(float(str(value)))
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
            ) from e
        return ab_datetime_now() + timedelta(seconds=seconds)

    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the access token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Returns:
            str: The extracted access token, or None if it is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())

    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the refresh token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.

        Returns:
            str: The extracted refresh token, or None if it is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

        Returns:
            str: The extracted token_expiry_date, or None if it is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())

    def _find_and_get_value_from_response(
        self,
        response_data: Mapping[str, Any],
        key_name: str,
        max_depth: int = 5,
        current_depth: int = 0,
    ) -> Any:
        """
        Recursively searches for a specified key in a nested mapping or list and returns its value if found.

        Keys at the current level are checked before descending; children are then searched
        depth-first in iteration order, and the first non-None match wins.

        Args:
            response_data (Mapping[str, Any]): The response data to search through, which can be a mapping or a list.
            key_name (str): The key to search for in the response data.
            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
            current_depth (int, optional): The current depth of the recursion. Defaults to 0.

        Returns:
            Any: The value associated with the specified key if found, otherwise None.

        Raises:
            ResponseKeysMaxRecurtionReached: If the search descends past max_depth.
        """
        if current_depth > max_depth:
            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
            message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response."
            raise ResponseKeysMaxRecurtionReached(
                internal_message=message, message=message, failure_type=FailureType.config_error
            )

        if isinstance(response_data, Mapping):
            # get from the root level
            if key_name in response_data:
                return response_data[key_name]
            # then from the nested objects
            children: List[Any] = list(response_data.values())
        elif isinstance(response_data, list):
            # get from the nested array objects
            children = response_data
        else:
            return None

        for child in children:
            result = self._find_and_get_value_from_response(
                child, key_name, max_depth, current_depth + 1
            )
            if result is not None:
                return result

        return None

    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        The implementation can define a message_repository if it wants debugging logs for HTTP requests
        """
        return _NOOP_MESSAGE_REPOSITORY

    def _log_response(self, response: requests.Response) -> None:
        """
        Logs the HTTP response using the message repository if it is available.

        Args:
            response (requests.Response): The HTTP response to log.
        """
        if self._message_repository:
            self._message_repository.log_message(
                Level.DEBUG,
                # formatting is deferred to the repository via a callable
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    self._NO_STREAM_NAME,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    # ----------------
    # ABSTR METHODS
    # ----------------

    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """The client id name to authenticate"""

    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """The client secret name to authenticate"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """The refresh token name to authenticate"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: Union[str, int]) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the expires_in field name"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns the request body to set on the refresh request"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator is designed to perform the refresh flow generically, without regard to how config fields are read or written, by delegating that behavior to the classes implementing the interface.
def __init__(
    self,
    refresh_token_error_status_codes: Tuple[int, ...] = (),
    refresh_token_error_key: str = "",
    refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
    """
    Store the optional settings used to recognise refresh-token errors.

    If all of refresh_token_error_status_codes, refresh_token_error_key and
    refresh_token_error_values are set, then HTTP errors with such params
    will be wrapped in AirbyteTracedException.
    """
    (
        self._refresh_token_error_status_codes,
        self._refresh_token_error_key,
        self._refresh_token_error_values,
    ) = (
        refresh_token_error_status_codes,
        refresh_token_error_key,
        refresh_token_error_values,
    )
If refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are all set, then HTTP errors matching these parameters will be wrapped in an AirbyteTracedException.
@property
def token_expiry_is_time_of_expiration(self) -> bool:
    """
    Whether the token expiry value is the moment the token stops being valid.

    The base implementation returns False, meaning the expiry value is
    treated as an amount of time rather than a point in time. Override to
    return True when the expiry is a date.
    """
    return False
Indicates whether the token expiry value is the date until which the token will be valid, rather than the amount of time it will be valid.
@property
def token_expiry_date_format(self) -> Optional[str]:
    """
    Format of the token expiry datetime.

    Override and return a format string when expires_in is returned as the
    expiration datetime instead of the number of seconds until the token
    expires. The base implementation returns None (no format configured).
    """
    no_format_configured: Optional[str] = None
    return no_format_configured
Format of the expiry datetime; set it if expires_in is returned as the expiration datetime instead of the number of seconds until the token expires.
def get_auth_header(self) -> Mapping[str, Any]:
    """
    Build the Authorization header to attach to outgoing requests.

    In the access-token flow the stored token is used directly; otherwise
    get_access_token() supplies it, refreshing first if needed.
    """
    if self._is_access_token_flow:
        bearer = self.access_token
    else:
        bearer = self.get_access_token()
    return {"Authorization": f"Bearer {bearer}"}
HTTP header to set on the requests
def get_access_token(self) -> str:
    """
    Return the access token, refreshing it first if it has expired.

    On refresh, the new token is stored through the access_token setter and
    its expiry value is passed to set_token_expiry_date(). The returned
    value is always read back from self.access_token.
    """
    if not self.token_has_expired():
        return self.access_token

    fresh_token, lifespan = self.refresh_access_token()
    self.access_token = fresh_token
    self.set_token_expiry_date(lifespan)
    return self.access_token
Returns the access token
def token_has_expired(self) -> bool:
    """Return True when the current time is later than the stored token expiry date."""
    current_time = ab_datetime_now()
    expiry_date = self.get_token_expiry_date()
    return current_time > expiry_date
Returns True if the token is expired
def build_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Returns the request body to set on the refresh request.

    The grant type, client id, client secret and refresh token are always
    included, keyed by their configured field names. Requested scopes are
    added under "scopes" when non-empty. Custom fields from
    get_refresh_request_body() are merged in last but never override a
    field that is already present.

    Override to define additional parameters.
    """
    payload: MutableMapping[str, Any] = {
        self.get_grant_type_name(): self.get_grant_type(),
        self.get_client_id_name(): self.get_client_id(),
        self.get_client_secret_name(): self.get_client_secret(),
        self.get_refresh_token_name(): self.get_refresh_token(),
    }

    # Evaluate each optional getter once; they may be non-trivial to compute.
    scopes = self.get_scopes()
    if scopes:
        payload["scopes"] = scopes

    custom_fields = self.get_refresh_request_body()
    if custom_fields:
        for key, val in custom_fields.items():
            # We defer to existing oauth constructs over custom configured fields
            payload.setdefault(key, val)

    return payload
Returns the request body to set on the refresh request
Override to define additional parameters
126 def build_refresh_request_headers(self) -> Mapping[str, Any] | None: 127 """ 128 Returns the request headers to set on the refresh request 129 130 """ 131 headers = self.get_refresh_request_headers() 132 return headers if headers else None
Returns the request headers to set on the refresh request
def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
    """
    Request a new access token and return it together with its expiry value.

    The refresh response is fetched, checked to contain an access token,
    and then the token and its expiry value are extracted from it.

    :return: a tuple of (access_token, token_lifespan)
    """
    refresh_response = self._make_handled_request()
    self._ensure_access_token_in_response(refresh_response)
    new_access_token = self._extract_access_token(refresh_response)
    token_lifespan = self._extract_token_expiry_date(refresh_response)
    return new_access_token, token_lifespan
Obtains a new access token and returns it together with its lifespan (expiration value)
Returns
a tuple of (access_token, token_lifespan)
@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
    """
    URL of the endpoint used to refresh the access token.

    May be None; when it is and an access token is set, the authenticator
    uses that access token directly instead of refreshing.
    """
Returns the endpoint to refresh the access token
@abstractmethod
def get_client_id_name(self) -> str:
    """Field name under which the client id is sent in the refresh request body."""
The name of the client ID field in the token refresh request body
@abstractmethod
def get_client_secret_name(self) -> str:
    """Field name under which the client secret is sent in the refresh request body."""
The name of the client secret field in the token refresh request body
@abstractmethod
def get_client_secret(self) -> str:
    """Client secret value included in the refresh request body."""
The client secret to authenticate
@abstractmethod
def get_refresh_token_name(self) -> str:
    """Field name under which the refresh token is sent in the refresh request body."""
The name of the refresh token field in the token refresh request body
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
    """Refresh token used to obtain a new access token when the current one expires (may be None)."""
The token used to refresh the access token when it expires
@abstractmethod
def get_token_expiry_date(self) -> AirbyteDateTime:
    """Date at which the current access token expires; compared against the current time to detect expiry."""
Expiration date of the access token
@abstractmethod
def set_token_expiry_date(self, value: Union[str, int]) -> None:
    """Store the access token expiry from `value`, the expiry value returned by refresh_access_token()."""
Setter for access token expiration date
@abstractmethod
def get_access_token_name(self) -> str:
    """Name of the field in the refresh response that holds the access token."""
Field to extract access token from in the response
@abstractmethod
def get_expires_in_name(self) -> str:
    """Name of the expires_in field in the refresh response."""
Returns the expires_in field name
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
    """Custom fields to add to the refresh request body; standard OAuth fields take precedence over them."""
Returns the request body to set on the refresh request
@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
    """Custom headers for the refresh request; an empty result is treated as no custom headers."""
Returns the request headers to set on the refresh request
@abstractmethod
def get_grant_type(self) -> str:
    """Value of the grant type sent when requesting an access token."""
Returns grant_type specified for requesting access_token