airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from abc import abstractmethod
from datetime import timedelta
from json import JSONDecodeError
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

import backoff
import requests
from requests.auth import AuthBase

from airbyte_cdk.models import FailureType, Level
from airbyte_cdk.sources.http_logger import format_http_message
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.airbyte_secrets_utils import add_to_secrets
from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse

from ..exceptions import DefaultBackoffException

logger = logging.getLogger("airbyte")
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()


class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Raised when the maximum recursion depth is exceeded while searching a token refresh
    response for a target key (see `_find_and_get_value_from_response`, used e.g. by
    `_make_handled_request`).
    """


class AbstractOauth2Authenticator(AuthBase):
    """
    Abstract base class for OAuth authenticators that implement the OAuth token refresh flow.

    The authenticator performs the refresh flow generically, without regard to how config
    fields are read or written, by delegating that behavior to the classes implementing
    the interface.
    """

    # Stream name handed to the HTTP message formatter; refresh requests belong to no stream.
    _NO_STREAM_NAME = None

    def __init__(
        self,
        refresh_token_error_status_codes: Tuple[int, ...] = (),
        refresh_token_error_key: str = "",
        refresh_token_error_values: Tuple[str, ...] = (),
    ) -> None:
        """
        If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set,
        then http errors with such params will be wrapped in AirbyteTracedException.

        :param refresh_token_error_status_codes: HTTP status codes that may indicate an invalid refresh token.
        :param refresh_token_error_key: key to look up in the JSON error body.
        :param refresh_token_error_values: values of that key that indicate an invalid refresh token.
        """
        self._refresh_token_error_status_codes = refresh_token_error_status_codes
        self._refresh_token_error_key = refresh_token_error_key
        self._refresh_token_error_values = refresh_token_error_values

    def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
        """Attach the HTTP headers required to authenticate on the HTTP request"""
        request.headers.update(self.get_auth_header())
        return request

    @property
    def _is_access_token_flow(self) -> bool:
        """True when no refresh endpoint is configured and an access token is already available."""
        return self.get_token_refresh_endpoint() is None and self.access_token is not None

    @property
    def token_expiry_is_time_of_expiration(self) -> bool:
        """
        Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid.
        """

        return False

    @property
    def token_expiry_date_format(self) -> Optional[str]:
        """
        Format of the datetime; set it if expires_in is returned as the expiration datetime instead of seconds until it expires
        """

        return None

    def get_auth_header(self) -> Mapping[str, Any]:
        """HTTP header to set on the requests"""
        # Without a refresh endpoint the configured access token is used verbatim.
        token = self.access_token if self._is_access_token_flow else self.get_access_token()
        return {"Authorization": f"Bearer {token}"}

    def get_access_token(self) -> str:
        """Returns the access token, refreshing it first if it has expired."""
        if self.token_has_expired():
            token, token_expiry_date = self.refresh_access_token()
            self.access_token = token
            self.set_token_expiry_date(token_expiry_date)

        return self.access_token

    def token_has_expired(self) -> bool:
        """Returns True if the token is expired"""
        return ab_datetime_now() > self.get_token_expiry_date()

    def build_refresh_request_body(self) -> Mapping[str, Any]:
        """
        Returns the request body to set on the refresh request.

        Override to define additional parameters.

        Client credentials (client_id and client_secret) are excluded from the body when
        refresh_request_headers contains an Authorization header (e.g., Basic auth).
        This is required by OAuth providers like Gong that expect credentials ONLY in the
        Authorization header and reject requests that include them in both places.
        """
        # Credentials already sent via the Authorization header must not be duplicated in the body.
        headers = self.get_refresh_request_headers()
        credentials_in_header = bool(headers) and "Authorization" in headers

        payload: MutableMapping[str, Any] = {
            self.get_grant_type_name(): self.get_grant_type(),
        }

        if not credentials_in_header:
            payload[self.get_client_id_name()] = self.get_client_id()
            payload[self.get_client_secret_name()] = self.get_client_secret()

        payload[self.get_refresh_token_name()] = self.get_refresh_token()

        scopes = self.get_scopes()
        if scopes:
            payload["scopes"] = scopes

        # The `if` also guards implementations that return None instead of an empty mapping.
        custom_body = self.get_refresh_request_body()
        if custom_body:
            for key, val in custom_body.items():
                # We defer to existing oauth constructs over custom configured fields
                if key not in payload:
                    payload[key] = val

        return payload

    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
        """
        Returns the request headers to set on the refresh request, or None when there are none.
        """
        headers = self.get_refresh_request_headers()
        return headers if headers else None

    def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
        """
        Requests a new access token and returns it together with its expiration datetime.

        :return: a tuple of (access_token, token_expiry_date)
        """
        response_json = self._make_handled_request()
        self._ensure_access_token_in_response(response_json)

        return (
            self._extract_access_token(response_json),
            self._extract_token_expiry_date(response_json),
        )

    # ----------------
    # PRIVATE METHODS
    # ----------------

    def _default_token_expiry_date(self) -> AirbyteDateTime:
        """
        Returns the default token expiry date (one hour from now).
        """
        # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
        default_token_expiry_duration_hours = 1  # 1 hour
        return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)

    def _wrap_refresh_token_exception(
        self, exception: requests.exceptions.RequestException
    ) -> bool:
        """
        Checks whether a failed refresh request indicates an invalid/expired refresh token.

        The exception matches when its response status code is one of the configured
        refresh-token error codes and the JSON error body maps the configured error key
        to one of the configured error values.

        Args:
            exception (requests.exceptions.RequestException): The exception raised during the request.

        Returns:
            bool: True if the exception is related to a refresh token error, False otherwise
            (including when there is no response or its body is not a JSON object).
        """
        response = exception.response
        if response is None or response.status_code not in self._refresh_token_error_status_codes:
            return False
        try:
            exception_content = response.json()
        except (JSONDecodeError, ValueError):
            # ValueError also covers requests' JSONDecodeError when it is backed by simplejson,
            # in which case it is not a json.JSONDecodeError subclass.
            return False
        # Error bodies may be JSON arrays or scalars, which have no `.get`.
        if not isinstance(exception_content, dict):
            return False
        return (
            exception_content.get(self._refresh_token_error_key)
            in self._refresh_token_error_values
        )

    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def _make_handled_request(self) -> Any:
        """
        Makes a handled HTTP request to refresh an OAuth token.

        This method sends a POST request to the token refresh endpoint with the necessary
        headers and body to obtain a new access token. It handles various exceptions that
        may occur during the request and logs the response for troubleshooting purposes.
        Retryable failures (DefaultBackoffException) are retried with exponential backoff
        for up to 300 seconds.

        Returns:
            Mapping[str, Any]: The JSON response from the token refresh endpoint.

        Raises:
            DefaultBackoffException: If the response status code is 429 (Too Many Requests)
                or any 5xx server error.
            AirbyteTracedException: If the refresh token is invalid or expired, prompting
                re-authentication.
            Exception: For any other exceptions that occur during the request.
        """
        try:
            response = requests.request(
                method="POST",
                url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
                data=self.build_refresh_request_body(),
                headers=self.build_refresh_request_headers(),
            )

            if not response.ok:
                # log the response even if the request failed for troubleshooting purposes
                self._log_response(response)
                response.raise_for_status()

            response_json = response.json()

            try:
                # extract the access token and add to secrets to avoid logging the raw value
                access_key = self._extract_access_token(response_json)
                if access_key:
                    add_to_secrets(access_key)
            except ResponseKeysMaxRecurtionReached:
                # could not find the access token in the response, so do nothing
                pass

            self._log_response(response)

            return response_json
        except requests.exceptions.RequestException as e:
            if e.response is not None:
                if e.response.status_code == 429 or e.response.status_code >= 500:
                    raise DefaultBackoffException(
                        request=e.response.request,
                        response=e.response,
                        failure_type=FailureType.transient_error,
                    )
                if self._wrap_refresh_token_exception(e):
                    message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
                    raise AirbyteTracedException(
                        internal_message=message, message=message, failure_type=FailureType.config_error
                    )
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e

    def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
        """
        Ensures that the access token is present in the response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Raises:
            Exception: If the access token is not found (or is empty) in the response data.
            ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token.
        """
        access_key = self._extract_access_token(response_data)
        if not access_key:
            raise Exception(
                f"Token refresh API response was missing access token {self.get_access_token_name()}"
            )

    def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
        """
        Parse a string or integer token expiration date into a datetime object.

        When `token_expiry_is_time_of_expiration` is set, `value` is parsed as a datetime;
        otherwise it is interpreted as a number of seconds from now.

        :return: expiration datetime
        :raises ValueError: if the value cannot be interpreted.
        """
        if self.token_expiry_is_time_of_expiration:
            if not self.token_expiry_date_format:
                raise ValueError(
                    f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
                )
            try:
                return ab_datetime_parse(str(value))
            except ValueError as e:
                raise ValueError(f"Invalid token expiry date format: {e}") from e
        else:
            try:
                # Only accept numeric values (as int/float/string) when no format specified
                seconds = int(float(str(value)))
                return ab_datetime_now() + timedelta(seconds=seconds)
            except (ValueError, TypeError) as e:
                raise ValueError(
                    f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
                ) from e

    def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the access token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the access token.

        Returns:
            The extracted access token, or None if not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_access_token_name())

    def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
        """
        Extracts the refresh token from the given response data.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the refresh token.

        Returns:
            The extracted refresh token, or None if not found.
        """
        return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())

    def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
        """
        Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.

        If the token_expiry_date is not found, it will return the existing token expiry date if it is
        set and not yet expired, or a default token expiry date otherwise.

        Args:
            response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.

        Returns:
            The token expiry date; never None.
        """
        expires_in = self._find_and_get_value_from_response(
            response_data, self.get_expires_in_name()
        )
        if expires_in is not None:
            return self._parse_token_expiration_date(expires_in)

        # expires_in is None: keep a still-valid existing expiry, else fall back to the default
        existing_expiry_date = self.get_token_expiry_date()
        if existing_expiry_date and not self.token_has_expired():
            return existing_expiry_date

        return self._default_token_expiry_date()

    def _find_and_get_value_from_response(
        self,
        response_data: Mapping[str, Any],
        key_name: str,
        max_depth: int = 5,
        current_depth: int = 0,
    ) -> Any:
        """
        Recursively searches for a specified key in a nested dictionary or list and returns its value if found.

        The search checks a dict's own keys first, then descends depth-first into its values
        (or a list's items) in order, returning the first non-None match.

        Args:
            response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list.
            key_name (str): The key to search for in the response data.
            max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
            current_depth (int, optional): The current depth of the recursion. Defaults to 0.

        Returns:
            Any: The value associated with the specified key if found, otherwise None.

        Raises:
            ResponseKeysMaxRecurtionReached: If a dict or list nested deeper than `max_depth` has to be
                searched before the key is found.
        """
        # Scalars cannot contain the key, so there is nothing to search regardless of depth;
        # raising here would fail benign responses that merely contain deeply nested scalars.
        if not isinstance(response_data, (dict, list)):
            return None

        if current_depth > max_depth:
            # this is needed to avoid an inf loop, possible with a very deep nesting observed.
            message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
            raise ResponseKeysMaxRecurtionReached(
                internal_message=message, message=message, failure_type=FailureType.config_error
            )

        if isinstance(response_data, dict):
            # get from the root level
            if key_name in response_data:
                return response_data[key_name]
            children: Any = response_data.values()
        else:
            children = response_data

        # get from the nested objects / array items
        for child in children:
            result = self._find_and_get_value_from_response(
                child, key_name, max_depth, current_depth + 1
            )
            if result is not None:
                return result

        return None

    @property
    def _message_repository(self) -> Optional[MessageRepository]:
        """
        The implementation can define a message_repository if it wants debugging logs for HTTP requests
        """
        return _NOOP_MESSAGE_REPOSITORY

    def _log_response(self, response: requests.Response) -> None:
        """
        Logs the HTTP response using the message repository if it is available.

        Args:
            response (requests.Response): The HTTP response to log.
        """
        if self._message_repository:
            # The formatter runs lazily, only if the repository actually emits the DEBUG message.
            self._message_repository.log_message(
                Level.DEBUG,
                lambda: format_http_message(
                    response,
                    "Refresh token",
                    "Obtains access token",
                    self._NO_STREAM_NAME,
                    is_auxiliary=True,
                    type="AUTH",
                ),
            )

    # ----------------
    # ABSTR METHODS
    # ----------------

    @abstractmethod
    def get_token_refresh_endpoint(self) -> Optional[str]:
        """Returns the endpoint to refresh the access token"""

    @abstractmethod
    def get_client_id_name(self) -> str:
        """The client id name to authenticate"""

    @abstractmethod
    def get_client_id(self) -> str:
        """The client id to authenticate"""

    @abstractmethod
    def get_client_secret_name(self) -> str:
        """The client secret name to authenticate"""

    @abstractmethod
    def get_client_secret(self) -> str:
        """The client secret to authenticate"""

    @abstractmethod
    def get_refresh_token_name(self) -> str:
        """The refresh token name to authenticate"""

    @abstractmethod
    def get_refresh_token(self) -> Optional[str]:
        """The token used to refresh the access token when it expires"""

    @abstractmethod
    def get_scopes(self) -> List[str]:
        """List of requested scopes"""

    @abstractmethod
    def get_token_expiry_date(self) -> AirbyteDateTime:
        """Expiration date of the access token"""

    @abstractmethod
    def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
        """Setter for access token expiration date"""

    @abstractmethod
    def get_access_token_name(self) -> str:
        """Field to extract access token from in the response"""

    @abstractmethod
    def get_expires_in_name(self) -> str:
        """Returns the expires_in field name"""

    @abstractmethod
    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Returns the request body to set on the refresh request"""

    @abstractmethod
    def get_refresh_request_headers(self) -> Mapping[str, Any]:
        """Returns the request headers to set on the refresh request"""

    @abstractmethod
    def get_grant_type(self) -> str:
        """Returns grant_type specified for requesting access_token"""

    @abstractmethod
    def get_grant_type_name(self) -> str:
        """Returns grant_type specified name for requesting access_token"""

    @property
    @abstractmethod
    def access_token(self) -> str:
        """Returns the access token"""

    @access_token.setter
    @abstractmethod
    def access_token(self, value: str) -> None:
        """Setter for the access token"""
class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Signals that the recursive key lookup over a token refresh response gave up
    because the maximum nesting depth was exceeded (for example while
    `_make_handled_request` looks for the access token).
    """

    pass
Raised when the maximum recursion depth is reached while trying to
find and get the target key in a token refresh response (for example, during _make_handled_request).
36class AbstractOauth2Authenticator(AuthBase): 37 """ 38 Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator 39 is designed to generically perform the refresh flow without regard to how config fields are get/set by 40 delegating that behavior to the classes implementing the interface. 41 """ 42 43 _NO_STREAM_NAME = None 44 45 def __init__( 46 self, 47 refresh_token_error_status_codes: Tuple[int, ...] = (), 48 refresh_token_error_key: str = "", 49 refresh_token_error_values: Tuple[str, ...] = (), 50 ) -> None: 51 """ 52 If all of refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are set, 53 then http errors with such params will be wrapped in AirbyteTracedException. 54 """ 55 self._refresh_token_error_status_codes = refresh_token_error_status_codes 56 self._refresh_token_error_key = refresh_token_error_key 57 self._refresh_token_error_values = refresh_token_error_values 58 59 def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: 60 """Attach the HTTP headers required to authenticate on the HTTP request""" 61 request.headers.update(self.get_auth_header()) 62 return request 63 64 @property 65 def _is_access_token_flow(self) -> bool: 66 return self.get_token_refresh_endpoint() is None and self.access_token is not None 67 68 @property 69 def token_expiry_is_time_of_expiration(self) -> bool: 70 """ 71 Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid. 
72 """ 73 74 return False 75 76 @property 77 def token_expiry_date_format(self) -> Optional[str]: 78 """ 79 Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires 80 """ 81 82 return None 83 84 def get_auth_header(self) -> Mapping[str, Any]: 85 """HTTP header to set on the requests""" 86 token = self.access_token if self._is_access_token_flow else self.get_access_token() 87 return {"Authorization": f"Bearer {token}"} 88 89 def get_access_token(self) -> str: 90 """Returns the access token""" 91 if self.token_has_expired(): 92 token, expires_in = self.refresh_access_token() 93 self.access_token = token 94 self.set_token_expiry_date(expires_in) 95 96 return self.access_token 97 98 def token_has_expired(self) -> bool: 99 """Returns True if the token is expired""" 100 return ab_datetime_now() > self.get_token_expiry_date() 101 102 def build_refresh_request_body(self) -> Mapping[str, Any]: 103 """ 104 Returns the request body to set on the refresh request. 105 106 Override to define additional parameters. 107 108 Client credentials (client_id and client_secret) are excluded from the body when 109 refresh_request_headers contains an Authorization header (e.g., Basic auth). 110 This is required by OAuth providers like Gong that expect credentials ONLY in the 111 Authorization header and reject requests that include them in both places. 
112 """ 113 # Check if credentials are being sent via Authorization header 114 headers = self.get_refresh_request_headers() 115 credentials_in_header = headers and "Authorization" in headers 116 117 # Only include client credentials in body if not already in header 118 include_client_credentials = not credentials_in_header 119 120 payload: MutableMapping[str, Any] = { 121 self.get_grant_type_name(): self.get_grant_type(), 122 } 123 124 # Only include client credentials in body if configured to do so and not in header 125 if include_client_credentials: 126 payload[self.get_client_id_name()] = self.get_client_id() 127 payload[self.get_client_secret_name()] = self.get_client_secret() 128 129 payload[self.get_refresh_token_name()] = self.get_refresh_token() 130 131 if self.get_scopes(): 132 payload["scopes"] = self.get_scopes() 133 134 if self.get_refresh_request_body(): 135 for key, val in self.get_refresh_request_body().items(): 136 # We defer to existing oauth constructs over custom configured fields 137 if key not in payload: 138 payload[key] = val 139 140 return payload 141 142 def build_refresh_request_headers(self) -> Mapping[str, Any] | None: 143 """ 144 Returns the request headers to set on the refresh request 145 146 """ 147 headers = self.get_refresh_request_headers() 148 return headers if headers else None 149 150 def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]: 151 """ 152 Returns the refresh token and its expiration datetime 153 154 :return: a tuple of (access_token, token_lifespan) 155 """ 156 response_json = self._make_handled_request() 157 self._ensure_access_token_in_response(response_json) 158 159 return ( 160 self._extract_access_token(response_json), 161 self._extract_token_expiry_date(response_json), 162 ) 163 164 # ---------------- 165 # PRIVATE METHODS 166 # ---------------- 167 168 def _default_token_expiry_date(self) -> AirbyteDateTime: 169 """ 170 Returns the default token expiry date 171 """ 172 # 1 hour was chosen as a middle 
ground to avoid unnecessary frequent refreshes and token expiration 173 default_token_expiry_duration_hours = 1 # 1 hour 174 return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours) 175 176 def _wrap_refresh_token_exception( 177 self, exception: requests.exceptions.RequestException 178 ) -> bool: 179 """ 180 Wraps and handles exceptions that occur during the refresh token process. 181 182 This method checks if the provided exception is related to a refresh token error 183 by examining the response status code and specific error content. 184 185 Args: 186 exception (requests.exceptions.RequestException): The exception raised during the request. 187 188 Returns: 189 bool: True if the exception is related to a refresh token error, False otherwise. 190 """ 191 try: 192 if exception.response is not None: 193 exception_content = exception.response.json() 194 else: 195 return False 196 except JSONDecodeError: 197 return False 198 return ( 199 exception.response.status_code in self._refresh_token_error_status_codes 200 and exception_content.get(self._refresh_token_error_key) 201 in self._refresh_token_error_values 202 ) 203 204 @backoff.on_exception( 205 backoff.expo, 206 DefaultBackoffException, 207 on_backoff=lambda details: logger.info( 208 f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..." 209 ), 210 max_time=300, 211 ) 212 def _make_handled_request(self) -> Any: 213 """ 214 Makes a handled HTTP request to refresh an OAuth token. 215 216 This method sends a POST request to the token refresh endpoint with the necessary 217 headers and body to obtain a new access token. It handles various exceptions that 218 may occur during the request and logs the response for troubleshooting purposes. 219 220 Returns: 221 Mapping[str, Any]: The JSON response from the token refresh endpoint. 
222 223 Raises: 224 DefaultBackoffException: If the response status code is 429 (Too Many Requests) 225 or any 5xx server error. 226 AirbyteTracedException: If the refresh token is invalid or expired, prompting 227 re-authentication. 228 Exception: For any other exceptions that occur during the request. 229 """ 230 try: 231 response = requests.request( 232 method="POST", 233 url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected. 234 data=self.build_refresh_request_body(), 235 headers=self.build_refresh_request_headers(), 236 ) 237 238 if not response.ok: 239 # log the response even if the request failed for troubleshooting purposes 240 self._log_response(response) 241 response.raise_for_status() 242 243 response_json = response.json() 244 245 try: 246 # extract the access token and add to secrets to avoid logging the raw value 247 access_key = self._extract_access_token(response_json) 248 if access_key: 249 add_to_secrets(access_key) 250 except ResponseKeysMaxRecurtionReached as e: 251 # could not find the access token in the response, so do nothing 252 pass 253 254 self._log_response(response) 255 256 return response_json 257 except requests.exceptions.RequestException as e: 258 if e.response is not None: 259 if e.response.status_code == 429 or e.response.status_code >= 500: 260 raise DefaultBackoffException( 261 request=e.response.request, 262 response=e.response, 263 failure_type=FailureType.transient_error, 264 ) 265 if self._wrap_refresh_token_exception(e): 266 message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings." 
267 raise AirbyteTracedException( 268 internal_message=message, message=message, failure_type=FailureType.config_error 269 ) 270 raise 271 except Exception as e: 272 raise Exception(f"Error while refreshing access token: {e}") from e 273 274 def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None: 275 """ 276 Ensures that the access token is present in the response data. 277 278 This method attempts to extract the access token from the provided response data. 279 If the access token is not found, it raises an exception indicating that the token 280 refresh API response was missing the access token. 281 282 Args: 283 response_data (Mapping[str, Any]): The response data from which to extract the access token. 284 285 Raises: 286 Exception: If the access token is not found in the response data. 287 ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token. 288 """ 289 try: 290 access_key = self._extract_access_token(response_data) 291 if not access_key: 292 raise Exception( 293 f"Token refresh API response was missing access token {self.get_access_token_name()}" 294 ) 295 except ResponseKeysMaxRecurtionReached as e: 296 raise e 297 298 def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime: 299 """ 300 Parse a string or integer token expiration date into a datetime object 301 302 :return: expiration datetime 303 """ 304 if self.token_expiry_is_time_of_expiration: 305 if not self.token_expiry_date_format: 306 raise ValueError( 307 f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required." 
308 ) 309 try: 310 return ab_datetime_parse(str(value)) 311 except ValueError as e: 312 raise ValueError(f"Invalid token expiry date format: {e}") 313 else: 314 try: 315 # Only accept numeric values (as int/float/string) when no format specified 316 seconds = int(float(str(value))) 317 return ab_datetime_now() + timedelta(seconds=seconds) 318 except (ValueError, TypeError): 319 raise ValueError( 320 f"Invalid expires_in value: {value}. Expected number of seconds when no format specified." 321 ) 322 323 def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any: 324 """ 325 Extracts the access token from the given response data. 326 327 Args: 328 response_data (Mapping[str, Any]): The response data from which to extract the access token. 329 330 Returns: 331 str: The extracted access token. 332 """ 333 return self._find_and_get_value_from_response(response_data, self.get_access_token_name()) 334 335 def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any: 336 """ 337 Extracts the refresh token from the given response data. 338 339 Args: 340 response_data (Mapping[str, Any]): The response data from which to extract the refresh token. 341 342 Returns: 343 str: The extracted refresh token. 344 """ 345 return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name()) 346 347 def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime: 348 """ 349 Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data. 350 351 If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date. 352 353 Args: 354 response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date. 355 356 Returns: 357 The extracted token_expiry_date or None if not found. 
358 """ 359 expires_in = self._find_and_get_value_from_response( 360 response_data, self.get_expires_in_name() 361 ) 362 if expires_in is not None: 363 return self._parse_token_expiration_date(expires_in) 364 365 # expires_in is None 366 existing_expiry_date = self.get_token_expiry_date() 367 if existing_expiry_date and not self.token_has_expired(): 368 return existing_expiry_date 369 370 return self._default_token_expiry_date() 371 372 def _find_and_get_value_from_response( 373 self, 374 response_data: Mapping[str, Any], 375 key_name: str, 376 max_depth: int = 5, 377 current_depth: int = 0, 378 ) -> Any: 379 """ 380 Recursively searches for a specified key in a nested dictionary or list and returns its value if found. 381 382 Args: 383 response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list. 384 key_name (str): The key to search for in the response data. 385 max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5. 386 current_depth (int, optional): The current depth of the recursion. Defaults to 0. 387 388 Returns: 389 Any: The value associated with the specified key if found, otherwise None. 390 391 Raises: 392 AirbyteTracedException: If the maximum recursion depth is reached without finding the key. 393 """ 394 if current_depth > max_depth: 395 # this is needed to avoid an inf loop, possible with a very deep nesting observed. 396 message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response." 
397 raise ResponseKeysMaxRecurtionReached( 398 internal_message=message, message=message, failure_type=FailureType.config_error 399 ) 400 401 if isinstance(response_data, dict): 402 # get from the root level 403 if key_name in response_data: 404 return response_data[key_name] 405 406 # get from the nested object 407 for _, value in response_data.items(): 408 result = self._find_and_get_value_from_response( 409 value, key_name, max_depth, current_depth + 1 410 ) 411 if result is not None: 412 return result 413 414 # get from the nested array object 415 elif isinstance(response_data, list): 416 for item in response_data: 417 result = self._find_and_get_value_from_response( 418 item, key_name, max_depth, current_depth + 1 419 ) 420 if result is not None: 421 return result 422 423 return None 424 425 @property 426 def _message_repository(self) -> Optional[MessageRepository]: 427 """ 428 The implementation can define a message_repository if it wants debugging logs for HTTP requests 429 """ 430 return _NOOP_MESSAGE_REPOSITORY 431 432 def _log_response(self, response: requests.Response) -> None: 433 """ 434 Logs the HTTP response using the message repository if it is available. 435 436 Args: 437 response (requests.Response): The HTTP response to log. 
438 """ 439 if self._message_repository: 440 self._message_repository.log_message( 441 Level.DEBUG, 442 lambda: format_http_message( 443 response, 444 "Refresh token", 445 "Obtains access token", 446 self._NO_STREAM_NAME, 447 is_auxiliary=True, 448 type="AUTH", 449 ), 450 ) 451 452 # ---------------- 453 # ABSTR METHODS 454 # ---------------- 455 456 @abstractmethod 457 def get_token_refresh_endpoint(self) -> Optional[str]: 458 """Returns the endpoint to refresh the access token""" 459 460 @abstractmethod 461 def get_client_id_name(self) -> str: 462 """The client id name to authenticate""" 463 464 @abstractmethod 465 def get_client_id(self) -> str: 466 """The client id to authenticate""" 467 468 @abstractmethod 469 def get_client_secret_name(self) -> str: 470 """The client secret name to authenticate""" 471 472 @abstractmethod 473 def get_client_secret(self) -> str: 474 """The client secret to authenticate""" 475 476 @abstractmethod 477 def get_refresh_token_name(self) -> str: 478 """The refresh token name to authenticate""" 479 480 @abstractmethod 481 def get_refresh_token(self) -> Optional[str]: 482 """The token used to refresh the access token when it expires""" 483 484 @abstractmethod 485 def get_scopes(self) -> List[str]: 486 """List of requested scopes""" 487 488 @abstractmethod 489 def get_token_expiry_date(self) -> AirbyteDateTime: 490 """Expiration date of the access token""" 491 492 @abstractmethod 493 def set_token_expiry_date(self, value: AirbyteDateTime) -> None: 494 """Setter for access token expiration date""" 495 496 @abstractmethod 497 def get_access_token_name(self) -> str: 498 """Field to extract access token from in the response""" 499 500 @abstractmethod 501 def get_expires_in_name(self) -> str: 502 """Returns the expires_in field name""" 503 504 @abstractmethod 505 def get_refresh_request_body(self) -> Mapping[str, Any]: 506 """Returns the request body to set on the refresh request""" 507 508 @abstractmethod 509 def 
get_refresh_request_headers(self) -> Mapping[str, Any]: 510 """Returns the request headers to set on the refresh request""" 511 512 @abstractmethod 513 def get_grant_type(self) -> str: 514 """Returns grant_type specified for requesting access_token""" 515 516 @abstractmethod 517 def get_grant_type_name(self) -> str: 518 """Returns grant_type specified name for requesting access_token""" 519 520 @property 521 @abstractmethod 522 def access_token(self) -> str: 523 """Returns the access token""" 524 525 @access_token.setter 526 @abstractmethod 527 def access_token(self, value: str) -> str: 528 """Setter for the access token"""
Abstract base class for OAuth authenticators that implement the OAuth token refresh flow. The authenticator is designed to perform the refresh flow generically, without regard to how config fields are read or written, by delegating that behavior to the classes implementing the interface.
def __init__(
    self,
    refresh_token_error_status_codes: Tuple[int, ...] = (),
    refresh_token_error_key: str = "",
    refresh_token_error_values: Tuple[str, ...] = (),
) -> None:
    """
    Remember the settings used to recognise refresh-token errors.

    When refresh_token_error_status_codes, refresh_token_error_key and
    refresh_token_error_values are all set, HTTP errors matching them are
    wrapped in an AirbyteTracedException.

    :param refresh_token_error_status_codes: HTTP status codes taking part in the match
    :param refresh_token_error_key: error key taking part in the match
    :param refresh_token_error_values: error values taking part in the match
    """
    (
        self._refresh_token_error_status_codes,
        self._refresh_token_error_key,
        self._refresh_token_error_values,
    ) = (
        refresh_token_error_status_codes,
        refresh_token_error_key,
        refresh_token_error_values,
    )
If refresh_token_error_status_codes, refresh_token_error_key, and refresh_token_error_values are all set, then HTTP errors matching these parameters are wrapped in an AirbyteTracedException.
@property
def token_expiry_is_time_of_expiration(self) -> bool:
    """
    Whether the token expiry value is the moment the token stops being valid
    (a point in time) rather than how long it stays valid (a duration).

    The base implementation always answers ``False``; subclasses may override it.
    """
    is_point_in_time = False
    return is_point_in_time
Indicates whether the token expiry value is the date until which the token will be valid, rather than the amount of time it will remain valid.
@property
def token_expiry_date_format(self) -> Optional[str]:
    """
    Format string of the token expiry datetime, if any.

    Set this only when ``expires_in`` is returned as the expiration datetime
    instead of the number of seconds until the token expires. The base
    implementation returns ``None`` (no format).
    """
    expiry_format: Optional[str] = None
    return expiry_format
Format of the expiry datetime. Set it only if expires_in is returned as the expiration datetime instead of the number of seconds until the token expires.
def get_auth_header(self) -> Mapping[str, Any]:
    """
    Build the ``Authorization`` header attached to each outgoing request.

    In the access-token flow the stored ``access_token`` is used as-is;
    otherwise ``get_access_token()`` is called, which refreshes the token
    first if it has expired.
    """
    if self._is_access_token_flow:
        bearer = self.access_token
    else:
        bearer = self.get_access_token()
    return {"Authorization": f"Bearer {bearer}"}
HTTP header to set on the requests
def get_access_token(self) -> str:
    """
    Return the current access token, refreshing it first if it has expired.

    On refresh, the new token is stored through the ``access_token`` setter
    and its expiry date through ``set_token_expiry_date`` before returning.
    """
    if not self.token_has_expired():
        return self.access_token

    # refresh_access_token() yields (token, expiry datetime), not a duration.
    fresh_token, fresh_expiry_date = self.refresh_access_token()
    self.access_token = fresh_token
    self.set_token_expiry_date(fresh_expiry_date)
    return self.access_token
Returns the access token
def token_has_expired(self) -> bool:
    """
    Check whether the current access token is past its expiry date.

    Returns True only when the current time is strictly later than
    ``get_token_expiry_date()``.
    """
    now = ab_datetime_now()
    expiry = self.get_token_expiry_date()
    return now > expiry
Returns True if the token is expired
def build_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Returns the request body to set on the refresh request.

    Override to define additional parameters.

    The body always carries the grant type and the refresh token. Client
    credentials (client_id and client_secret) are excluded from the body when
    the refresh request headers contain an Authorization header (e.g., Basic
    auth); the header name is matched case-insensitively. This is required by
    OAuth providers like Gong that expect credentials ONLY in the Authorization
    header and reject requests that include them in both places.

    Non-empty scopes are sent under "scopes", and custom fields from
    get_refresh_request_body() are merged in without overriding built-in fields.
    """
    headers = self.get_refresh_request_headers()
    # HTTP field names are case-insensitive, so "authorization" must count too;
    # otherwise credentials would be sent both in the header and in the body.
    credentials_in_header = bool(headers) and any(
        str(name).lower() == "authorization" for name in headers
    )

    payload: MutableMapping[str, Any] = {
        self.get_grant_type_name(): self.get_grant_type(),
    }

    # Only include client credentials in the body when they are not already in the header.
    if not credentials_in_header:
        payload[self.get_client_id_name()] = self.get_client_id()
        payload[self.get_client_secret_name()] = self.get_client_secret()

    payload[self.get_refresh_token_name()] = self.get_refresh_token()

    scopes = self.get_scopes()
    if scopes:
        payload["scopes"] = scopes

    custom_body = self.get_refresh_request_body()
    if custom_body:
        for key, val in custom_body.items():
            # We defer to existing oauth constructs over custom configured fields
            if key not in payload:
                payload[key] = val

    return payload
Returns the request body to set on the refresh request.
Override to define additional parameters.
Client credentials (client_id and client_secret) are excluded from the body when refresh_request_headers contains an Authorization header (e.g., Basic auth). This is required by OAuth providers like Gong that expect credentials ONLY in the Authorization header and reject requests that include them in both places.
142 def build_refresh_request_headers(self) -> Mapping[str, Any] | None: 143 """ 144 Returns the request headers to set on the refresh request 145 146 """ 147 headers = self.get_refresh_request_headers() 148 return headers if headers else None
Returns the request headers to set on the refresh request
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
    """
    Request a new access token and return it together with its expiry date.

    Performs the refresh request via ``_make_handled_request``, passes the
    resulting JSON to ``_ensure_access_token_in_response``, then extracts the
    access token and its expiry date from that JSON.

    :return: a tuple of (access_token, token_expiry_date), where
        token_expiry_date is the AirbyteDateTime at which the new token expires
        (not a lifespan in seconds)
    """
    response_json = self._make_handled_request()
    self._ensure_access_token_in_response(response_json)

    return (
        self._extract_access_token(response_json),
        self._extract_token_expiry_date(response_json),
    )
Returns the new access token and its expiration datetime
Returns
a tuple of (access_token, token_expiry_date)
@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
    """
    URL of the endpoint that issues refreshed access tokens.

    Return ``None`` when no refresh endpoint is configured.
    """
Returns the endpoint to refresh the access token
@abstractmethod
def get_client_id_name(self) -> str:
    """Name of the field under which the client ID is sent in the refresh request body."""
The name of the field under which the client ID is sent in the refresh request body
@abstractmethod
def get_client_secret_name(self) -> str:
    """Name of the field under which the client secret is sent in the refresh request body."""
The name of the field under which the client secret is sent in the refresh request body
@abstractmethod
def get_client_secret(self) -> str:
    """Client secret used to authenticate the token refresh request."""
The client secret to authenticate
@abstractmethod
def get_refresh_token_name(self) -> str:
    """Name of the field under which the refresh token is sent in the refresh request body."""
The name of the field under which the refresh token is sent in the refresh request body
@abstractmethod
def get_refresh_token(self) -> Optional[str]:
    """Token exchanged for a new access token once the current one expires; may be ``None``."""
The token used to refresh the access token when it expires
@abstractmethod
def get_token_expiry_date(self) -> AirbyteDateTime:
    """Moment at which the current access token expires."""
Expiration date of the access token
@abstractmethod
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
    """Store ``value`` as the moment at which the current access token expires."""
Setter for access token expiration date
@abstractmethod
def get_access_token_name(self) -> str:
    """Key looked up (including inside nested objects) in the refresh response to obtain the access token."""
Field to extract access token from in the response
@abstractmethod
def get_expires_in_name(self) -> str:
    """Key looked up in the refresh response to obtain the token expiry value (``expires_in``-style field)."""
Returns the expires_in field name
@abstractmethod
def get_refresh_request_body(self) -> Mapping[str, Any]:
    """
    Custom fields to add to the refresh request body.

    Keys already set by the built-in OAuth fields are not overridden.
    """
Returns additional fields for the refresh request body; built-in OAuth fields take precedence over custom fields with the same key
@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
    """
    Headers to send with the refresh request.

    If they include an Authorization header, client credentials are left out
    of the refresh request body.
    """
Returns the request headers to set on the refresh request
@abstractmethod
def get_grant_type(self) -> str:
    """Value of the grant type field sent in the refresh request body when requesting an access token."""
Returns grant_type specified for requesting access_token