airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from abc import abstractmethod
from datetime import timedelta
from json import JSONDecodeError
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

import backoff
import requests
from requests.auth import AuthBase

from airbyte_cdk.models import FailureType, Level
from airbyte_cdk.sources.http_logger import format_http_message
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse

from ..exceptions import DefaultBackoffException

logger = logging.getLogger("airbyte")
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()


class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is exceeded while searching the token
    refresh response for a target key (see `_find_and_get_value_from_response`).
    """


class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator
    is designed to generically perform the refresh flow without regard to how config fields are get/set by
    delegating that behavior to the classes implementing the interface.
    """

    _NO_STREAM_NAME = None

    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then http errors with such params will be wrapped in AirbyteTracedException.
        """
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the HTTP headers required to authenticate on the HTTP request"""
        request.headers.update(self.get_auth_header())
        return request

    @property
    def _is_access_token_flow(self) -> bool:
        """True when no refresh endpoint is configured but an access token is available."""
        return self.get_token_refresh_endpoint() is None and self.access_token is not None

    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
        """

        return False

    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the datetime; set it if expires_in is returned as the expiration datetime instead of seconds until it expires
        """

        return None

    def get_auth_header(self) -> Mapping[str, Any]:
        """HTTP header to set on the requests"""
        token = self.access_token if self._is_access_token_flow else self.get_access_token()
        return {"Authorization": f"Bearer {token}"}

    def get_access_token(self) -> str:
        """Returns the access token, refreshing it first when it has expired"""
        if self.token_has_expired():
            token, expires_in = self.refresh_access_token()
            self.access_token = token
            self.set_token_expiry_date(expires_in)

        return self.access_token

    def token_has_expired(self) -> bool:
        """Returns True if the token is expired"""
        return ab_datetime_now() > self.get_token_expiry_date()

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request

        Override to define additional parameters
        """
        payload: MutableMapping[str, Any] = {
            self.get_grant_type_name(): self.get_grant_type(),
            self.get_client_id_name(): self.get_client_id(),
            self.get_client_secret_name(): self.get_client_secret(),
            self.get_refresh_token_name(): self.get_refresh_token(),
        }

        # Each getter is called once so the value that is tested is the value that is used.
        scopes = self.get_scopes()
        if scopes:
            payload["scopes"] = scopes

        extra_body = self.get_refresh_request_body()
        if extra_body:
            for key, val in extra_body.items():
                # We defer to existing oauth constructs over custom configured fields
                if key not in payload:
                    payload[key] = val

        return payload

    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
        """
        Returns the request headers to set on the refresh request, or None when there are none
        """
        headers = self.get_refresh_request_headers()
        return headers if headers else None

    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
        """
        Requests a new access token and returns it with its expiration datetime

        :return: a tuple of (access_token, token_expiry_date)
        """
        response_json = self._make_handled_request()
        self._ensure_access_token_in_response(response_json)

        return (
            self._extract_access_token(response_json),
            self._extract_token_expiry_date(response_json),
        )

    # ----------------
    # PRIVATE METHODS
    # ----------------

    def _default_token_expiry_date(self) -> AirbyteDateTime:
        """
        Returns the default token expiry date
        """
        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
        default_token_expiry_duration_hours = 1  # 1 hour
        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)

    def _wrap_refresh_token_exception(
        self, exception: requests.exceptions.RequestException
    ) -> bool:
        """
        Tells whether a request exception matches the configured refresh token error.

        The exception matches when its response status code is one of
        `refresh_token_error_status_codes` and the JSON body holds one of
        `refresh_token_error_values` under `refresh_token_error_key`.

        Args:
            exception (requests.exceptions.RequestException): The exception raised during the request.

        Returns:
            bool: True if the exception is related to a refresh token error, False otherwise
            (including when there is no response, or the body is not a JSON object).
        """
        response = exception.response
        if response is None:
            return False

        try:
            exception_content = response.json()
        except ValueError:
            # The json, simplejson and requests JSONDecodeError classes all derive from ValueError.
            return False

        if not isinstance(exception_content, dict):
            # A JSON array/string/number body has no keys to look up.
            return False

        return (
            response.status_code in self._refresh_token_error_status_codes
            and exception_content.get(self._refresh_token_error_key)
            in self._refresh_token_error_values
        )

    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the necessary
        headers and body to obtain a new access token. It handles various exceptions that
        may occur during the request and logs the response for troubleshooting purposes.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication.
            Exception: For any other exceptions that occur during the request.
        """
        try:
            # NOTE(review): no `timeout` is passed, so a stalled token endpoint blocks this call.
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )

            if not response.ok:
                # log the response even if the request failed for troubleshooting purposes
                self._log_response(response)
                response.raise_for_status()

            response_json = response.json()

            try:
                # extract the access token and add to secrets to avoid logging the raw value
                access_key = self._extract_access_token(response_json)
                if access_key:
                    add_to_secrets(access_key)
            except ResponseKeysMaxRecurtionReached:
                # could not find the access token in the response, so there is nothing to mask
                pass

            self._log_response(response)

            return response_json
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(request=e.response.request, response=e.response)
            if self._wrap_refresh_token_exception(e):
                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                raise AirbyteTracedException(
                    internal_message=message, message=message, failure_type=FailureType.config_error
                )
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e

    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
        """
        Ensures that the access token is present in the response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Raises:
            Exception: If the access token is not found in the response data.
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
        """
        # ResponseKeysMaxRecurtionReached propagates unchanged from the extraction below.
        access_key = self._extract_access_token(response_data)
        if not access_key:
            raise Exception(
                f"Token refresh API response was missing access token {self.get_access_token_name()}"
            )

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """
        Parse a string or integer token expiration date into a datetime object

        When `token_expiry_is_time_of_expiration` is true the value is parsed as a datetime;
        otherwise it is read as a number of seconds from now.

        :return: expiration datetime
        :raises ValueError: if the value cannot be parsed
        """
        if self.token_expiry_is_time_of_expiration:
            if not self.token_expiry_date_format:
                raise ValueError(
                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
                )
            try:
                return ab_datetime_parse(str(value))
            except ValueError as e:
                raise ValueError(f"Invalid token expiry date format: {e}") from e

        try:
            # Only accept numeric values (as int/float/string) when no format specified
            seconds = int(float(str(value)))
            return ab_datetime_now() + timedelta(seconds=seconds)
        except (ValueError, TypeError, OverflowError) as e:
            # OverflowError covers infinite or out-of-range values such as "inf".
            raise ValueError(
                f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
            ) from e

    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the access token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Returns:
            The value found under `get_access_token_name()`, or None if the key is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())

    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the refresh token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.

        Returns:
            The value found under `get_refresh_token_name()`, or None if the key is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
        """
        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

        If the token_expiry_date is not found, it will return an existing token expiry date if set
        and not yet expired, or a default token expiry date.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

        Returns:
            The parsed expiry date, the existing unexpired expiry date, or the default expiry date.
        """
        expires_in = self._find_and_get_value_from_response(
            response_data, self.get_expires_in_name()
        )
        if expires_in is not None:
            return self._parse_token_expiration_date(expires_in)

        # expires_in is None
        existing_expiry_date = self.get_token_expiry_date()
        if existing_expiry_date and not self.token_has_expired():
            return existing_expiry_date

        return self._default_token_expiry_date()

    def _find_and_get_value_from_response(
        self,
        response_data: Mapping[str, Any],
        key_name: str,
        max_depth: int = 5,
        current_depth: int = 0,
    ) -> Any:
        """
        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.

        Args:
            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
            key_name (str): The key to search for in the response data.
            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
            current_depth (int, optional): The current depth of the recursion. Defaults to 0.

        Returns:
            Any: The value associated with the specified key if found, otherwise None.

        Raises:
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is exceeded without finding the key.
        """
        if current_depth > max_depth:
            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
            raise ResponseKeysMaxRecurtionReached(
                internal_message=message, message=message, failure_type=FailureType.config_error
            )

        if isinstance(response_data, dict):
            # get from the root level
            if key_name in response_data:
                return response_data[key_name]

            # get from the nested object
            for value in response_data.values():
                result = self._find_and_get_value_from_response(
                    value, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        # get from the nested array object
        elif isinstance(response_data, list):
            for item in response_data:
                result = self._find_and_get_value_from_response(
                    item, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        return None

    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        The implementation can define a message_repository if it wants debugging logs for HTTP requests
        """
        return _NOOP_MESSAGE_REPOSITORY

    def _log_response(self, response: requests.Response) -> None:
        """
        Logs the HTTP response using the message repository if it is available.

        Args:
            response (requests.Response): The HTTP response to log.
        """
        if self._message_repository:
            self._message_repository.log_message(
                Level.DEBUG,
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    self._NO_STREAM_NAME,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    # ----------------
    # ABSTR METHODS
    # ----------------

    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """The client id name to authenticate"""

    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """The client secret name to authenticate"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """The refresh token name to authenticate"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the expires_in field name"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns the request body to set on the refresh request"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is exceeded while searching the token
    refresh response for a target key (see `_find_and_get_value_from_response`).
    """
Raised when the maximum level of recursion is reached while trying to
find and get the target key during _make_handled_request.
class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator
    is designed to generically perform the refresh flow without regard to how config fields are get/set by
    delegating that behavior to the classes implementing the interface.
    """

    _NO_STREAM_NAME = None

    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then http errors with such params will be wrapped in AirbyteTracedException.
        """
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the HTTP headers required to authenticate on the HTTP request"""
        request.headers.update(self.get_auth_header())
        return request

    @property
    def _is_access_token_flow(self) -> bool:
        """True when no refresh endpoint is configured but an access token is available."""
        return self.get_token_refresh_endpoint() is None and self.access_token is not None

    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
        """

        return False

    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the datetime; set it if expires_in is returned as the expiration datetime instead of seconds until it expires
        """

        return None

    def get_auth_header(self) -> Mapping[str, Any]:
        """HTTP header to set on the requests"""
        token = self.access_token if self._is_access_token_flow else self.get_access_token()
        return {"Authorization": f"Bearer {token}"}

    def get_access_token(self) -> str:
        """Returns the access token, refreshing it first when it has expired"""
        if self.token_has_expired():
            token, expires_in = self.refresh_access_token()
            self.access_token = token
            self.set_token_expiry_date(expires_in)

        return self.access_token

    def token_has_expired(self) -> bool:
        """Returns True if the token is expired"""
        return ab_datetime_now() > self.get_token_expiry_date()

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request

        Override to define additional parameters
        """
        payload: MutableMapping[str, Any] = {
            self.get_grant_type_name(): self.get_grant_type(),
            self.get_client_id_name(): self.get_client_id(),
            self.get_client_secret_name(): self.get_client_secret(),
            self.get_refresh_token_name(): self.get_refresh_token(),
        }

        # Each getter is called once so the value that is tested is the value that is used.
        scopes = self.get_scopes()
        if scopes:
            payload["scopes"] = scopes

        extra_body = self.get_refresh_request_body()
        if extra_body:
            for key, val in extra_body.items():
                # We defer to existing oauth constructs over custom configured fields
                if key not in payload:
                    payload[key] = val

        return payload

    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
        """
        Returns the request headers to set on the refresh request, or None when there are none
        """
        headers = self.get_refresh_request_headers()
        return headers if headers else None

    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
        """
        Requests a new access token and returns it with its expiration datetime

        :return: a tuple of (access_token, token_expiry_date)
        """
        response_json = self._make_handled_request()
        self._ensure_access_token_in_response(response_json)

        return (
            self._extract_access_token(response_json),
            self._extract_token_expiry_date(response_json),
        )

    # ----------------
    # PRIVATE METHODS
    # ----------------

    def _default_token_expiry_date(self) -> AirbyteDateTime:
        """
        Returns the default token expiry date
        """
        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
        default_token_expiry_duration_hours = 1  # 1 hour
        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)

    def _wrap_refresh_token_exception(
        self, exception: requests.exceptions.RequestException
    ) -> bool:
        """
        Tells whether a request exception matches the configured refresh token error.

        The exception matches when its response status code is one of
        `refresh_token_error_status_codes` and the JSON body holds one of
        `refresh_token_error_values` under `refresh_token_error_key`.

        Args:
            exception (requests.exceptions.RequestException): The exception raised during the request.

        Returns:
            bool: True if the exception is related to a refresh token error, False otherwise
            (including when there is no response, or the body is not a JSON object).
        """
        response = exception.response
        if response is None:
            return False

        try:
            exception_content = response.json()
        except ValueError:
            # The json, simplejson and requests JSONDecodeError classes all derive from ValueError.
            return False

        if not isinstance(exception_content, dict):
            # A JSON array/string/number body has no keys to look up.
            return False

        return (
            response.status_code in self._refresh_token_error_status_codes
            and exception_content.get(self._refresh_token_error_key)
            in self._refresh_token_error_values
        )

    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the necessary
        headers and body to obtain a new access token. It handles various exceptions that
        may occur during the request and logs the response for troubleshooting purposes.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication.
            Exception: For any other exceptions that occur during the request.
        """
        try:
            # NOTE(review): no `timeout` is passed, so a stalled token endpoint blocks this call.
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )

            if not response.ok:
                # log the response even if the request failed for troubleshooting purposes
                self._log_response(response)
                response.raise_for_status()

            response_json = response.json()

            try:
                # extract the access token and add to secrets to avoid logging the raw value
                access_key = self._extract_access_token(response_json)
                if access_key:
                    add_to_secrets(access_key)
            except ResponseKeysMaxRecurtionReached:
                # could not find the access token in the response, so there is nothing to mask
                pass

            self._log_response(response)

            return response_json
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(request=e.response.request, response=e.response)
            if self._wrap_refresh_token_exception(e):
                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                raise AirbyteTracedException(
                    internal_message=message, message=message, failure_type=FailureType.config_error
                )
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e

    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
        """
        Ensures that the access token is present in the response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Raises:
            Exception: If the access token is not found in the response data.
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
        """
        # ResponseKeysMaxRecurtionReached propagates unchanged from the extraction below.
        access_key = self._extract_access_token(response_data)
        if not access_key:
            raise Exception(
                f"Token refresh API response was missing access token {self.get_access_token_name()}"
            )

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """
        Parse a string or integer token expiration date into a datetime object

        When `token_expiry_is_time_of_expiration` is true the value is parsed as a datetime;
        otherwise it is read as a number of seconds from now.

        :return: expiration datetime
        :raises ValueError: if the value cannot be parsed
        """
        if self.token_expiry_is_time_of_expiration:
            if not self.token_expiry_date_format:
                raise ValueError(
                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
                )
            try:
                return ab_datetime_parse(str(value))
            except ValueError as e:
                raise ValueError(f"Invalid token expiry date format: {e}") from e

        try:
            # Only accept numeric values (as int/float/string) when no format specified
            seconds = int(float(str(value)))
            return ab_datetime_now() + timedelta(seconds=seconds)
        except (ValueError, TypeError, OverflowError) as e:
            # OverflowError covers infinite or out-of-range values such as "inf".
            raise ValueError(
                f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
            ) from e

    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the access token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Returns:
            The value found under `get_access_token_name()`, or None if the key is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())

    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the refresh token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.

        Returns:
            The value found under `get_refresh_token_name()`, or None if the key is absent.
        """
        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
        """
        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

        If the token_expiry_date is not found, it will return an existing token expiry date if set
        and not yet expired, or a default token expiry date.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

        Returns:
            The parsed expiry date, the existing unexpired expiry date, or the default expiry date.
        """
        expires_in = self._find_and_get_value_from_response(
            response_data, self.get_expires_in_name()
        )
        if expires_in is not None:
            return self._parse_token_expiration_date(expires_in)

        # expires_in is None
        existing_expiry_date = self.get_token_expiry_date()
        if existing_expiry_date and not self.token_has_expired():
            return existing_expiry_date

        return self._default_token_expiry_date()

    def _find_and_get_value_from_response(
        self,
        response_data: Mapping[str, Any],
        key_name: str,
        max_depth: int = 5,
        current_depth: int = 0,
    ) -> Any:
        """
        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.

        Args:
            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
            key_name (str): The key to search for in the response data.
            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
            current_depth (int, optional): The current depth of the recursion. Defaults to 0.

        Returns:
            Any: The value associated with the specified key if found, otherwise None.

        Raises:
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is exceeded without finding the key.
        """
        if current_depth > max_depth:
            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
            raise ResponseKeysMaxRecurtionReached(
                internal_message=message, message=message, failure_type=FailureType.config_error
            )

        if isinstance(response_data, dict):
            # get from the root level
            if key_name in response_data:
                return response_data[key_name]

            # get from the nested object
            for value in response_data.values():
                result = self._find_and_get_value_from_response(
                    value, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        # get from the nested array object
        elif isinstance(response_data, list):
            for item in response_data:
                result = self._find_and_get_value_from_response(
                    item, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        return None

    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        The implementation can define a message_repository if it wants debugging logs for HTTP requests
        """
        return _NOOP_MESSAGE_REPOSITORY

    def _log_response(self, response: requests.Response) -> None:
        """
        Logs the HTTP response using the message repository if it is available.

        Args:
            response (requests.Response): The HTTP response to log.
        """
        if self._message_repository:
            self._message_repository.log_message(
                Level.DEBUG,
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    self._NO_STREAM_NAME,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    # ----------------
    # ABSTR METHODS
    # ----------------

    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """The client id name to authenticate"""

    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """The client secret name to authenticate"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """The refresh token name to authenticate"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the expires_in field name"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns the request body to set on the refresh request"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
Abstract base class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator is designed to perform the refresh flow generically, without regard to how config fields are read or written, by delegating that behavior to the classes implementing the interface.
def __init__(
    self,
    refresh_token_error_status_codes: Tuple[int, ...] = (),
    refresh_token_error_key: str = "",
    refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
    """
    Record the criteria that identify a refresh-token error response.

    If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
    then HTTP errors matching those parameters will be wrapped in AirbyteTracedException.

    :param refresh_token_error_status_codes: HTTP status codes to match; empty by default
    :param refresh_token_error_key: response key to inspect; empty by default
    :param refresh_token_error_values: values of that key to match; empty by default
    """
    # The constructor only stores the three settings; nothing is validated here.
    (
        self._refresh_token_error_status_codes,
        self._refresh_token_error_key,
        self._refresh_token_error_values,
    ) = (
        refresh_token_error_status_codes,
        refresh_token_error_key,
        refresh_token_error_values,
    )
If refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are all set, then HTTP errors matching those parameters will be wrapped in AirbyteTracedException.
@property
def token_expiry_is_time_of_expiration(self) -> bool:
    """
    Whether the token expiry value is the date until which the token is valid,
    as opposed to the amount of time it will remain valid.

    This implementation always answers False.
    """
    expiry_is_point_in_time = False
    return expiry_is_point_in_time
Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
@property
def token_expiry_date_format(self) -> Optional[str]:
    """
    Format of the expiry datetime; provide it when expires_in is returned as the
    expiration datetime instead of the number of seconds until expiry.

    This implementation always answers None.
    """
    no_format: Optional[str] = None
    return no_format
Format of the datetime; set it if expires_in is returned as the expiration datetime instead of the number of seconds until it expires
def get_auth_header(self) -> Mapping[str, Any]:
    """
    Build the Authorization header to set on outgoing requests.

    The stored access token is used as-is in the access-token flow; otherwise the
    token comes from get_access_token().

    :return: a single-entry mapping holding the bearer Authorization header
    """
    if self._is_access_token_flow:
        bearer = self.access_token
    else:
        bearer = self.get_access_token()
    return {"Authorization": f"Bearer {bearer}"}
HTTP header to set on the requests
def get_access_token(self) -> str:
    """
    Return the access token, refreshing it first when it has expired.

    On expiry, the result of refresh_access_token() is stored: the token through the
    access_token setter and the expiry through set_token_expiry_date().
    """
    if not self.token_has_expired():
        return self.access_token

    refreshed_token, expiry_date = self.refresh_access_token()
    self.access_token = refreshed_token
    self.set_token_expiry_date(expiry_date)
    return self.access_token
Returns the access token
def token_has_expired(self) -> bool:
    """
    Tell whether the access token is past its expiry date.

    :return: True when the current time is strictly later than get_token_expiry_date()
    """
    now = ab_datetime_now()
    expires_at = self.get_token_expiry_date()
    return now > expires_at
Returns True if the token is expired
def build_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Returns the request body to set on the refresh request

    The body always carries the grant type, client id, client secret and refresh token, each under the
    field name given by the matching ``get_*_name`` getter. Non-empty scopes are added under ``"scopes"``,
    and entries from ``get_refresh_request_body()`` are merged in without replacing keys already present.

    Override to define additional parameters

    :return: the payload for the token refresh request
    """
    payload: MutableMapping[str, Any] = {
        self.get_grant_type_name(): self.get_grant_type(),
        self.get_client_id_name(): self.get_client_id(),
        self.get_client_secret_name(): self.get_client_secret(),
        self.get_refresh_token_name(): self.get_refresh_token(),
    }

    # Read each getter once so the value that is checked is the value that is used.
    scopes = self.get_scopes()
    if scopes:
        payload["scopes"] = scopes

    extra_body = self.get_refresh_request_body()
    if extra_body:
        for key, val in extra_body.items():
            # We defer to existing oauth constructs over custom configured fields
            payload.setdefault(key, val)

    return payload
Returns the request body to set on the refresh request
Override to define additional parameters
126 def build_refresh_request_headers(self) -> Mapping[str, Any] | None: 127 """ 128 Returns the request headers to set on the refresh request 129 130 """ 131 headers = self.get_refresh_request_headers() 132 return headers if headers else None
Returns the request headers to set on the refresh request
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
    """
    Request a new access token and return it with its expiration datetime.

    The refresh response is first checked by _ensure_access_token_in_response, then the
    token and the expiry date are extracted from it.

    :return: a tuple of (access_token, token_expiry_date)
    """
    token_response = self._make_handled_request()
    self._ensure_access_token_in_response(token_response)

    access_token = self._extract_access_token(token_response)
    expiry_date = self._extract_token_expiry_date(token_response)
    return access_token, expiry_date
Returns the access token and its expiration datetime
Returns
a tuple of (access_token, token_expiry_date)
@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
    """
    Returns the endpoint to refresh the access token.

    May be None: `_is_access_token_flow` treats a None endpoint together with a set access token
    as the signal to use the stored access token without refreshing it.
    """
Returns the endpoint to refresh the access token
@abstractmethod
def get_client_id_name(self) -> str:
    """
    The client id name to authenticate.

    Used by `build_refresh_request_body` as the key under which the client id is placed in the refresh payload.
    """
The client id name to authenticate
@abstractmethod
def get_client_secret_name(self) -> str:
    """
    The client secret name to authenticate.

    Used by `build_refresh_request_body` as the key under which the client secret is placed in the refresh payload.
    """
The client secret name to authenticate
@abstractmethod
def get_client_secret(self) -> str:
    """
    The client secret to authenticate.

    `build_refresh_request_body` sends this value under the key returned by `get_client_secret_name`.
    """
The client secret to authenticate
@abstractmethod
def get_refresh_token_name(self) -> str:
    """
    The refresh token name to authenticate.

    Used by `build_refresh_request_body` as the key under which the refresh token is placed in the refresh payload.
    """
The refresh token name to authenticate
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
    """
    The token used to refresh the access token when it expires.

    `build_refresh_request_body` places the returned value in the payload as-is, including when it is None.
    """
The token used to refresh the access token when it expires
@abstractmethod
def get_token_expiry_date(self) -> AirbyteDateTime:
    """
    Expiration date of the access token.

    `token_has_expired` compares this value against the current time.
    """
Expiration date of the access token
@abstractmethod
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
    """
    Setter for access token expiration date.

    `get_access_token` calls this with the expiry returned by `refresh_access_token` after a refresh.

    :param value: the new expiration date of the access token
    """
Setter for access token expiration date
@abstractmethod
def get_access_token_name(self) -> str:
    """
    Name of the field from which the access token is extracted in the refresh response.
    """
Field to extract access token from in the response
@abstractmethod
def get_expires_in_name(self) -> str:
    """
    Returns the expires_in field name.

    The name is passed to `_find_and_get_value_from_response` to look up the expiry value in the refresh response.
    """
Returns the expires_in field name
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Returns the request body to set on the refresh request.

    `build_refresh_request_body` merges these entries into the payload, skipping any key that is already
    present, so they cannot override the standard OAuth fields.
    """
Returns the request body to set on the refresh request
@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
    """
    Returns the request headers to set on the refresh request.

    `build_refresh_request_headers` turns an empty result into None.
    """
Returns the request headers to set on the refresh request
@abstractmethod
def get_grant_type(self) -> str:
    """
    Returns grant_type specified for requesting access_token.

    `build_refresh_request_body` sends this value under the key returned by `get_grant_type_name`.
    """
Returns grant_type specified for requesting access_token