airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from abc import abstractmethod
from datetime import timedelta
from json import JSONDecodeError
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

import backoff
import requests
from requests.auth import AuthBase

from airbyte_cdk.models import FailureType, Level
from airbyte_cdk.sources.http_logger import format_http_message
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse

from ..exceptions import DefaultBackoffException

logger = logging.getLogger("airbyte")
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()


class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the max level of recursion is reached while trying to
    find-and-get the target key in a response, e.g. during `_make_handled_request`.
    """


class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator
    is designed to generically perform the refresh flow without regard to how config fields are get/set by
    delegating that behavior to the classes implementing the interface.
    """

    _NO_STREAM_NAME = None

    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then http errors with such params will be wrapped in AirbyteTracedException.

        Args:
            refresh_token_error_status_codes: HTTP status codes that may indicate a refresh token error.
            refresh_token_error_key: Key of the error response body whose value is inspected.
            refresh_token_error_values: Values of that key that identify a refresh token error.
        """
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the HTTP headers required to authenticate on the HTTP request"""
        request.headers.update(self.get_auth_header())
        return request

    @property
    def _is_access_token_flow(self) -> bool:
        """True when no refresh endpoint is configured but an access token is available."""
        return self.get_token_refresh_endpoint() is None and self.access_token is not None

    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
        """

        return False

    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the datetime; set it if expires_in is returned as the expiration datetime instead of seconds until it expires
        """

        return None

    def get_auth_header(self) -> Mapping[str, Any]:
        """HTTP header to set on the requests"""
        # In the access-token flow the stored token is used as-is; otherwise it is refreshed when expired.
        token = self.access_token if self._is_access_token_flow else self.get_access_token()
        return {"Authorization": f"Bearer {token}"}

    def get_access_token(self) -> str:
        """Returns the access token, refreshing it first if it has expired"""
        if self.token_has_expired():
            token, expires_in = self.refresh_access_token()
            self.access_token = token
            self.set_token_expiry_date(expires_in)

        return self.access_token

    def token_has_expired(self) -> bool:
        """Returns True if the token is expired"""
        return ab_datetime_now() > self.get_token_expiry_date()

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request

        Override to define additional parameters
        """
        payload: MutableMapping[str, Any] = {
            self.get_grant_type_name(): self.get_grant_type(),
            self.get_client_id_name(): self.get_client_id(),
            self.get_client_secret_name(): self.get_client_secret(),
            self.get_refresh_token_name(): self.get_refresh_token(),
        }

        scopes = self.get_scopes()
        if scopes:
            payload["scopes"] = scopes

        extra_body = self.get_refresh_request_body()
        if extra_body:
            for key, val in extra_body.items():
                # We defer to existing oauth constructs over custom configured fields
                if key not in payload:
                    payload[key] = val

        return payload

    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
        """
        Returns the request headers to set on the refresh request, or None when there are none
        """
        headers = self.get_refresh_request_headers()
        return headers if headers else None

    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
        """
        Requests a new access token and returns it together with its expiration datetime

        :return: a tuple of (access_token, token_expiry_date)
        """
        response_json = self._make_handled_request()
        self._ensure_access_token_in_response(response_json)

        return (
            self._extract_access_token(response_json),
            self._extract_token_expiry_date(response_json),
        )

    # ----------------
    # PRIVATE METHODS
    # ----------------

    def _default_token_expiry_date(self) -> AirbyteDateTime:
        """
        Returns the default token expiry date (one hour from now)
        """
        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
        default_token_expiry_duration_hours = 1  # 1 hour
        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)

    def _wrap_refresh_token_exception(
        self, exception: requests.exceptions.RequestException
    ) -> bool:
        """
        Tells whether a request exception should be reported as a refresh token error.

        The exception matches when its response status code is one of the configured
        error status codes and the configured error key of the JSON body holds one of
        the configured error values.

        Args:
            exception (requests.exceptions.RequestException): The exception raised during the request.

        Returns:
            bool: True if the exception is related to a refresh token error, False otherwise
            (including when there is no response, or its body is not a JSON object).
        """
        response = exception.response
        if response is None:
            return False

        try:
            exception_content = response.json()
        except ValueError:
            # Both the stdlib and the simplejson flavours of JSONDecodeError derive from ValueError,
            # so this covers whichever one `requests` raises.
            return False

        # A JSON array/scalar body has no `.get`; it cannot carry the configured error key.
        if not isinstance(exception_content, dict):
            return False

        return (
            response.status_code in self._refresh_token_error_status_codes
            and exception_content.get(self._refresh_token_error_key)
            in self._refresh_token_error_values
        )

    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the necessary
        headers and body to obtain a new access token. It handles various exceptions that
        may occur during the request and logs the response for troubleshooting purposes.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication.
            Exception: For any other exceptions that occur during the request.
        """
        try:
            # NOTE(review): no `timeout` is passed, so this call can wait indefinitely if the
            # endpoint never responds; consider adding one.
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )

            if not response.ok:
                # log the response even if the request failed for troubleshooting purposes
                self._log_response(response)
                response.raise_for_status()

            response_json = response.json()

            try:
                # extract the access token and add to secrets to avoid logging the raw value
                access_key = self._extract_access_token(response_json)
                if access_key:
                    add_to_secrets(access_key)
            except ResponseKeysMaxRecurtionReached:
                # could not find the access token in the response, so do nothing
                pass

            self._log_response(response)

            return response_json
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                # rate limiting and server errors are retried by the backoff decorator
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(
                        request=e.response.request,
                        response=e.response,
                        failure_type=FailureType.transient_error,
                    )
            if self._wrap_refresh_token_exception(e):
                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                raise AirbyteTracedException(
                    internal_message=message, message=message, failure_type=FailureType.config_error
                )
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e

    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
        """
        Ensures that the access token is present in the response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Raises:
            Exception: If the access token is not found in the response data.
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
        """
        # ResponseKeysMaxRecurtionReached raised by the lookup simply propagates to the caller.
        access_key = self._extract_access_token(response_data)
        if not access_key:
            raise Exception(
                f"Token refresh API response was missing access token {self.get_access_token_name()}"
            )

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """
        Parse a string or integer token expiration date into a datetime object

        When `token_expiry_is_time_of_expiration` is set the value is parsed as a datetime,
        otherwise it is read as a number of seconds from now.

        :return: expiration datetime
        :raises ValueError: if the value cannot be interpreted
        """
        if self.token_expiry_is_time_of_expiration:
            if not self.token_expiry_date_format:
                raise ValueError(
                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
                )
            try:
                return ab_datetime_parse(str(value))
            except ValueError as e:
                raise ValueError(f"Invalid token expiry date format: {e}") from e

        try:
            # Only accept numeric values (as int/float/string) when no format specified
            seconds = int(float(str(value)))
            return ab_datetime_now() + timedelta(seconds=seconds)
        except (ValueError, TypeError, OverflowError) as e:
            # OverflowError covers infinite or out-of-range durations such as "inf"
            raise ValueError(
                f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
            ) from e

    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the access token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Returns:
            The extracted access token, or None if it is not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())

    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the refresh token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.

        Returns:
            The extracted refresh token, or None if it is not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
        """
        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

        If the token_expiry_date is not found, it will return an existing token expiry date if set
        and not yet expired, or a default token expiry date.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

        Returns:
            The parsed token_expiry_date, the existing unexpired expiry date, or the default expiry date.
        """
        expires_in = self._find_and_get_value_from_response(
            response_data, self.get_expires_in_name()
        )
        if expires_in is not None:
            return self._parse_token_expiration_date(expires_in)

        # expires_in is None
        existing_expiry_date = self.get_token_expiry_date()
        if existing_expiry_date and not self.token_has_expired():
            return existing_expiry_date

        return self._default_token_expiry_date()

    def _find_and_get_value_from_response(
        self,
        response_data: Mapping[str, Any],
        key_name: str,
        max_depth: int = 5,
        current_depth: int = 0,
    ) -> Any:
        """
        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.

        Args:
            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
            key_name (str): The key to search for in the response data.
            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
            current_depth (int, optional): The current depth of the recursion. Defaults to 0.

        Returns:
            Any: The value associated with the specified key if found, otherwise None.

        Raises:
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is exceeded during the search.
        """
        if current_depth > max_depth:
            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
            raise ResponseKeysMaxRecurtionReached(
                internal_message=message, message=message, failure_type=FailureType.config_error
            )

        if isinstance(response_data, dict):
            # get from the root level
            if key_name in response_data:
                return response_data[key_name]

            # get from the nested object
            for value in response_data.values():
                result = self._find_and_get_value_from_response(
                    value, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        # get from the nested array object
        elif isinstance(response_data, list):
            for item in response_data:
                result = self._find_and_get_value_from_response(
                    item, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        return None

    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        The implementation can define a message_repository if it wants debugging logs for HTTP requests
        """
        return _NOOP_MESSAGE_REPOSITORY

    def _log_response(self, response: requests.Response) -> None:
        """
        Logs the HTTP response using the message repository if it is available.

        Args:
            response (requests.Response): The HTTP response to log.
        """
        if self._message_repository:
            self._message_repository.log_message(
                Level.DEBUG,
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    self._NO_STREAM_NAME,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    # ----------------
    # ABSTR METHODS
    # ----------------

    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """The client id name to authenticate"""

    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """The client secret name to authenticate"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """The refresh token name to authenticate"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the expires_in field name"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns the request body to set on the refresh request"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Signals that the recursive find-and-get lookup of a target key went past
    its maximum recursion depth, e.g. during the `_make_handled_request`.
    """
Raised when the maximum recursion depth is reached while trying to
find and get the target key, e.g. during _make_handled_request.
class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator
    is designed to generically perform the refresh flow without regard to how config fields are get/set by
    delegating that behavior to the classes implementing the interface.
    """

    _NO_STREAM_NAME = None

    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then http errors with such params will be wrapped in AirbyteTracedException.

        Args:
            refresh_token_error_status_codes: HTTP status codes that may indicate a refresh token error.
            refresh_token_error_key: Key of the error response body whose value is inspected.
            refresh_token_error_values: Values of that key that identify a refresh token error.
        """
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the HTTP headers required to authenticate on the HTTP request"""
        request.headers.update(self.get_auth_header())
        return request

    @property
    def _is_access_token_flow(self) -> bool:
        """True when no refresh endpoint is configured but an access token is available."""
        return self.get_token_refresh_endpoint() is None and self.access_token is not None

    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
        """

        return False

    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the datetime; set it if expires_in is returned as the expiration datetime instead of seconds until it expires
        """

        return None

    def get_auth_header(self) -> Mapping[str, Any]:
        """HTTP header to set on the requests"""
        # In the access-token flow the stored token is used as-is; otherwise it is refreshed when expired.
        token = self.access_token if self._is_access_token_flow else self.get_access_token()
        return {"Authorization": f"Bearer {token}"}

    def get_access_token(self) -> str:
        """Returns the access token, refreshing it first if it has expired"""
        if self.token_has_expired():
            token, expires_in = self.refresh_access_token()
            self.access_token = token
            self.set_token_expiry_date(expires_in)

        return self.access_token

    def token_has_expired(self) -> bool:
        """Returns True if the token is expired"""
        return ab_datetime_now() > self.get_token_expiry_date()

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request

        Override to define additional parameters
        """
        payload: MutableMapping[str, Any] = {
            self.get_grant_type_name(): self.get_grant_type(),
            self.get_client_id_name(): self.get_client_id(),
            self.get_client_secret_name(): self.get_client_secret(),
            self.get_refresh_token_name(): self.get_refresh_token(),
        }

        scopes = self.get_scopes()
        if scopes:
            payload["scopes"] = scopes

        extra_body = self.get_refresh_request_body()
        if extra_body:
            for key, val in extra_body.items():
                # We defer to existing oauth constructs over custom configured fields
                if key not in payload:
                    payload[key] = val

        return payload

    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
        """
        Returns the request headers to set on the refresh request, or None when there are none
        """
        headers = self.get_refresh_request_headers()
        return headers if headers else None

    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
        """
        Requests a new access token and returns it together with its expiration datetime

        :return: a tuple of (access_token, token_expiry_date)
        """
        response_json = self._make_handled_request()
        self._ensure_access_token_in_response(response_json)

        return (
            self._extract_access_token(response_json),
            self._extract_token_expiry_date(response_json),
        )

    # ----------------
    # PRIVATE METHODS
    # ----------------

    def _default_token_expiry_date(self) -> AirbyteDateTime:
        """
        Returns the default token expiry date (one hour from now)
        """
        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
        default_token_expiry_duration_hours = 1  # 1 hour
        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)

    def _wrap_refresh_token_exception(
        self, exception: requests.exceptions.RequestException
    ) -> bool:
        """
        Tells whether a request exception should be reported as a refresh token error.

        The exception matches when its response status code is one of the configured
        error status codes and the configured error key of the JSON body holds one of
        the configured error values.

        Args:
            exception (requests.exceptions.RequestException): The exception raised during the request.

        Returns:
            bool: True if the exception is related to a refresh token error, False otherwise
            (including when there is no response, or its body is not a JSON object).
        """
        response = exception.response
        if response is None:
            return False

        try:
            exception_content = response.json()
        except ValueError:
            # Both the stdlib and the simplejson flavours of JSONDecodeError derive from ValueError,
            # so this covers whichever one `requests` raises.
            return False

        # A JSON array/scalar body has no `.get`; it cannot carry the configured error key.
        if not isinstance(exception_content, dict):
            return False

        return (
            response.status_code in self._refresh_token_error_status_codes
            and exception_content.get(self._refresh_token_error_key)
            in self._refresh_token_error_values
        )

    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the necessary
        headers and body to obtain a new access token. It handles various exceptions that
        may occur during the request and logs the response for troubleshooting purposes.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication.
            Exception: For any other exceptions that occur during the request.
        """
        try:
            # NOTE(review): no `timeout` is passed, so this call can wait indefinitely if the
            # endpoint never responds; consider adding one.
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )

            if not response.ok:
                # log the response even if the request failed for troubleshooting purposes
                self._log_response(response)
                response.raise_for_status()

            response_json = response.json()

            try:
                # extract the access token and add to secrets to avoid logging the raw value
                access_key = self._extract_access_token(response_json)
                if access_key:
                    add_to_secrets(access_key)
            except ResponseKeysMaxRecurtionReached:
                # could not find the access token in the response, so do nothing
                pass

            self._log_response(response)

            return response_json
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                # rate limiting and server errors are retried by the backoff decorator
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(
                        request=e.response.request,
                        response=e.response,
                        failure_type=FailureType.transient_error,
                    )
            if self._wrap_refresh_token_exception(e):
                message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                raise AirbyteTracedException(
                    internal_message=message, message=message, failure_type=FailureType.config_error
                )
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e

    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
        """
        Ensures that the access token is present in the response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Raises:
            Exception: If the access token is not found in the response data.
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
        """
        # ResponseKeysMaxRecurtionReached raised by the lookup simply propagates to the caller.
        access_key = self._extract_access_token(response_data)
        if not access_key:
            raise Exception(
                f"Token refresh API response was missing access token {self.get_access_token_name()}"
            )

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """
        Parse a string or integer token expiration date into a datetime object

        When `token_expiry_is_time_of_expiration` is set the value is parsed as a datetime,
        otherwise it is read as a number of seconds from now.

        :return: expiration datetime
        :raises ValueError: if the value cannot be interpreted
        """
        if self.token_expiry_is_time_of_expiration:
            if not self.token_expiry_date_format:
                raise ValueError(
                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
                )
            try:
                return ab_datetime_parse(str(value))
            except ValueError as e:
                raise ValueError(f"Invalid token expiry date format: {e}") from e

        try:
            # Only accept numeric values (as int/float/string) when no format specified
            seconds = int(float(str(value)))
            return ab_datetime_now() + timedelta(seconds=seconds)
        except (ValueError, TypeError, OverflowError) as e:
            # OverflowError covers infinite or out-of-range durations such as "inf"
            raise ValueError(
                f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
            ) from e

    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the access token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Returns:
            The extracted access token, or None if it is not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())

    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the refresh token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.

        Returns:
            The extracted refresh token, or None if it is not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
        """
        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

        If the token_expiry_date is not found, it will return an existing token expiry date if set
        and not yet expired, or a default token expiry date.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

        Returns:
            The parsed token_expiry_date, the existing unexpired expiry date, or the default expiry date.
        """
        expires_in = self._find_and_get_value_from_response(
            response_data, self.get_expires_in_name()
        )
        if expires_in is not None:
            return self._parse_token_expiration_date(expires_in)

        # expires_in is None
        existing_expiry_date = self.get_token_expiry_date()
        if existing_expiry_date and not self.token_has_expired():
            return existing_expiry_date

        return self._default_token_expiry_date()

    def _find_and_get_value_from_response(
        self,
        response_data: Mapping[str, Any],
        key_name: str,
        max_depth: int = 5,
        current_depth: int = 0,
    ) -> Any:
        """
        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.

        Args:
            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
            key_name (str): The key to search for in the response data.
            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
            current_depth (int, optional): The current depth of the recursion. Defaults to 0.

        Returns:
            Any: The value associated with the specified key if found, otherwise None.

        Raises:
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is exceeded during the search.
        """
        if current_depth > max_depth:
            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
            raise ResponseKeysMaxRecurtionReached(
                internal_message=message, message=message, failure_type=FailureType.config_error
            )

        if isinstance(response_data, dict):
            # get from the root level
            if key_name in response_data:
                return response_data[key_name]

            # get from the nested object
            for value in response_data.values():
                result = self._find_and_get_value_from_response(
                    value, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        # get from the nested array object
        elif isinstance(response_data, list):
            for item in response_data:
                result = self._find_and_get_value_from_response(
                    item, key_name, max_depth, current_depth + 1
                )
                if result is not None:
                    return result

        return None

    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        The implementation can define a message_repository if it wants debugging logs for HTTP requests
        """
        return _NOOP_MESSAGE_REPOSITORY

    def _log_response(self, response: requests.Response) -> None:
        """
        Logs the HTTP response using the message repository if it is available.

        Args:
            response (requests.Response): The HTTP response to log.
        """
        if self._message_repository:
            self._message_repository.log_message(
                Level.DEBUG,
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    self._NO_STREAM_NAME,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    # ----------------
    # ABSTR METHODS
    # ----------------

    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """The client id name to authenticate"""

    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """The client secret name to authenticate"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """The refresh token name to authenticate"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the expires_in field name"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns the request body to set on the refresh request"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
Abstract base class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator is designed to perform the refresh flow generically, without regard to how config fields are read or written, by delegating that behavior to the classes implementing the interface.
def __init__(
    self,
    refresh_token_error_status_codes: Tuple[int, ...] = (),
    refresh_token_error_key: str = "",
    refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
    """
    Store the criteria that describe a refresh-token error response.

    If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
    then HTTP errors matching those params will be wrapped in AirbyteTracedException.

    Args:
        refresh_token_error_status_codes: Status codes to match. Defaults to an empty tuple.
        refresh_token_error_key: Key to match. Defaults to an empty string.
        refresh_token_error_values: Values to match. Defaults to an empty tuple.

    NOTE(review): the matching itself happens outside this method; this constructor only
    stores the three values on private attributes.
    """
    self._refresh_token_error_status_codes = refresh_token_error_status_codes
    self._refresh_token_error_key = refresh_token_error_key
    self._refresh_token_error_values = refresh_token_error_values
If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, then http errors with such params will be wrapped in AirbyteTracedException.
@property
def token_expiry_is_time_of_expiration(self) -> bool:
    """
    Indicates whether the token expiry value is the date until which the token will be valid,
    rather than the amount of time it will be valid.

    This base implementation always returns False.
    """

    return False
Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
@property
def token_expiry_date_format(self) -> Optional[str]:
    """
    Format of the datetime; applicable if expires_in is returned as the expiration datetime
    instead of the number of seconds until the token expires.

    This base implementation always returns None.
    """

    return None
Format of the datetime; applicable if expires_in is returned as the expiration datetime instead of the number of seconds until the token expires.
def get_auth_header(self) -> Mapping[str, Any]:
    """
    Build the HTTP header that authenticates outgoing requests.

    In the access token flow the stored access token is used as-is; otherwise the token
    is obtained through get_access_token().

    Returns:
        A mapping with a single "Authorization" entry carrying the bearer token.
    """
    if self._is_access_token_flow:
        bearer = self.access_token
    else:
        bearer = self.get_access_token()
    return {"Authorization": f"Bearer {bearer}"}
HTTP header to set on the requests
def get_access_token(self) -> str:
    """
    Return the current access token, refreshing it first when it has expired.

    On refresh, the new token is stored on `access_token` and the new expiry date is
    passed to set_token_expiry_date().
    """
    if not self.token_has_expired():
        return self.access_token

    new_token, new_expiry = self.refresh_access_token()
    self.access_token = new_token
    self.set_token_expiry_date(new_expiry)
    return self.access_token
Returns the access token
def token_has_expired(self) -> bool:
    """Tell whether the current time is past the token expiry date."""
    now = ab_datetime_now()
    expires_at = self.get_token_expiry_date()
    return now > expires_at
Returns True if the token is expired
def build_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Returns the request body to set on the refresh request.

    The body always contains the grant type, client id, client secret and refresh token,
    each keyed by the field name returned by the matching `get_*_name()` method. When
    get_scopes() is non-empty it is added under "scopes". Entries from
    get_refresh_request_body() are then merged in, but never override a key that is
    already present.

    Override to define additional parameters.

    Returns:
        The assembled request body.
    """
    payload: MutableMapping[str, Any] = {
        self.get_grant_type_name(): self.get_grant_type(),
        self.get_client_id_name(): self.get_client_id(),
        self.get_client_secret_name(): self.get_client_secret(),
        self.get_refresh_token_name(): self.get_refresh_token(),
    }

    # Evaluate each getter exactly once so that the emptiness check and the value that is
    # actually used can never disagree (and no work is repeated).
    scopes = self.get_scopes()
    if scopes:
        payload["scopes"] = scopes

    extra_body = self.get_refresh_request_body()
    if extra_body:
        for key, val in extra_body.items():
            # We defer to existing oauth constructs over custom configured fields
            payload.setdefault(key, val)

    return payload
Returns the request body to set on the refresh request
Override to define additional parameters
def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
    """
    Provide the headers for the refresh request.

    Returns:
        Whatever get_refresh_request_headers() yields, or None when that value is empty.
    """
    return self.get_refresh_request_headers() or None
Returns the request headers to set on the refresh request
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
    """
    Request a new access token and work out when it expires.

    The refresh response is checked for an access token before anything is extracted
    from it.

    :return: a tuple of (access_token, token_expiry_date)
    """
    refresh_response = self._make_handled_request()
    self._ensure_access_token_in_response(refresh_response)

    new_token = self._extract_access_token(refresh_response)
    new_expiry = self._extract_token_expiry_date(refresh_response)
    return new_token, new_expiry
Returns the access token and its expiration datetime
Returns
a tuple of (access_token, token_expiry_date)
@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
    """
    Returns the endpoint to refresh the access token.

    May return None. `_is_access_token_flow` treats a None endpoint combined with a
    non-None access token as the access token flow, in which the stored token is used
    without refreshing.
    """
Returns the endpoint to refresh the access token
@abstractmethod
def get_client_id_name(self) -> str:
    """
    Returns the name of the client id field used to authenticate.

    build_refresh_request_body() uses it as the key under which the client id is sent.
    """
The client id name to authenticate
@abstractmethod
def get_client_secret_name(self) -> str:
    """
    Returns the name of the client secret field used to authenticate.

    build_refresh_request_body() uses it as the key under which the client secret is sent.
    """
The client secret name to authenticate
@abstractmethod
def get_client_secret(self) -> str:
    """
    Returns the client secret used to authenticate.

    build_refresh_request_body() sends it under the key given by get_client_secret_name().
    """
The client secret to authenticate
@abstractmethod
def get_refresh_token_name(self) -> str:
    """
    Returns the name of the refresh token field used to authenticate.

    build_refresh_request_body() uses it as the key under which the refresh token is sent.
    """
The refresh token name to authenticate
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
    """
    Returns the token used to refresh the access token when it expires.

    May be None. build_refresh_request_body() places the returned value, whatever it is,
    under the key given by get_refresh_token_name().
    """
The token used to refresh the access token when it expires
@abstractmethod
def get_token_expiry_date(self) -> AirbyteDateTime:
    """
    Returns the expiration date of the access token.

    token_has_expired() compares this value against the current time.
    """
Expiration date of the access token
@abstractmethod
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
    """
    Setter for the access token expiration date.

    get_access_token() calls it with the expiry date returned by refresh_access_token().
    """
Setter for access token expiration date
@abstractmethod
def get_access_token_name(self) -> str:
    """Returns the name of the response field from which the access token is extracted."""
Field to extract access token from in the response
@abstractmethod
def get_expires_in_name(self) -> str:
    """
    Returns the expires_in field name.

    The returned name is the key searched for in the refresh response when the token
    expiry date is extracted.
    """
Returns the expires_in field name
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Returns the additional request body entries to set on the refresh request.

    build_refresh_request_body() merges these entries into the payload, skipping any key
    that the payload already contains.
    """
Returns the request body to set on the refresh request
@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
    """
    Returns the request headers to set on the refresh request.

    build_refresh_request_headers() passes the value through, turning an empty result
    into None.
    """
Returns the request headers to set on the refresh request
@abstractmethod
def get_grant_type(self) -> str:
    """
    Returns the grant_type specified for requesting the access_token.

    build_refresh_request_body() sends it under the key given by get_grant_type_name().
    """
Returns grant_type specified for requesting access_token