airbyte_cdk.sources.connector_state_manager
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import copy
from dataclasses import dataclass
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union, cast

from airbyte_cdk.models import (
    AirbyteMessage,
    AirbyteStateBlob,
    AirbyteStateMessage,
    AirbyteStateType,
    AirbyteStreamState,
    StreamDescriptor,
)
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.models.airbyte_protocol import AirbyteGlobalState, AirbyteStateBlob


@dataclass(frozen=True)
class HashableStreamDescriptor:
    """
    Helper class that overrides the existing StreamDescriptor class that is auto generated from the Airbyte Protocol and
    freezes its fields so that it can be used as a hash key. This is only marked public because we use it outside for unit tests.
    """

    name: str
    namespace: Optional[str] = None


class ConnectorStateManager:
    """
    ConnectorStateManager consolidates the various forms of a stream's incoming state message (STREAM / GLOBAL) under a common
    interface. It also provides methods to extract and update state.
    """

    def __init__(self, state: Optional[List[AirbyteStateMessage]] = None):
        """
        :param state: The incoming state messages, or None when there is no prior state
        :raises ValueError: If a GLOBAL state message carries a truthy shared_state
        """
        shared_state, per_stream_states = self._extract_from_state_message(state)

        # We explicitly throw an error if we receive a GLOBAL state message that contains a shared_state because API sources are
        # designed to checkpoint state independently of one another. API sources should never be emitting a state message where
        # shared_state is populated. Rather than define how to handle shared_state without a clear use case, we're opting to throw an
        # error instead and if/when we find one, we will then implement processing of the shared_state value.
        if shared_state:
            raise ValueError(
                "Received a GLOBAL AirbyteStateMessage that contains a shared_state. This library only ever generates per-STREAM "
                "STATE messages so this was not generated by this connector. This must be an orchestrator or platform error. GLOBAL "
                "state messages with shared_state will not be processed correctly. "
            )
        self.per_stream_states = per_stream_states

    def get_stream_state(
        self, stream_name: str, namespace: Optional[str]
    ) -> MutableMapping[str, Any]:
        """
        Retrieves the state of a given stream based on its descriptor (name + namespace).
        :param stream_name: Name of the stream being fetched
        :param namespace: Namespace of the stream being fetched
        :return: A deep copy of the per-stream state as a plain dict, or an empty dict if the stream has no state
        """
        stream_state: Optional[AirbyteStateBlob] = self.per_stream_states.get(
            HashableStreamDescriptor(name=stream_name, namespace=namespace)
        )
        if stream_state:
            # Deep copy so callers can mutate the returned mapping without altering the stored blob.
            return copy.deepcopy(dict(stream_state.__dict__))
        return {}

    def update_state_for_stream(
        self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]
    ) -> None:
        """
        Overwrites the state blob of a specific stream based on the provided stream name and optional namespace
        :param stream_name: The name of the stream whose state is being updated
        :param namespace: The namespace of the stream if it exists
        :param value: A stream state mapping that replaces any previously stored state for the stream
        """
        stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
        self.per_stream_states[stream_descriptor] = AirbyteStateBlob(value)

    def create_state_message(self, stream_name: str, namespace: Optional[str]) -> AirbyteMessage:
        """
        Generates an AirbyteMessage using the current per-stream state of a specified stream
        :param stream_name: The name of the stream for the message that is being created
        :param namespace: The namespace of the stream for the message that is being created
        :return: The Airbyte state message to be emitted by the connector during a sync; an empty
            AirbyteStateBlob is used when the stream has no stored state
        """
        hashable_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
        stream_state = self.per_stream_states.get(hashable_descriptor) or AirbyteStateBlob()

        return AirbyteMessage(
            type=MessageType.STATE,
            state=AirbyteStateMessage(
                type=AirbyteStateType.STREAM,
                stream=AirbyteStreamState(
                    stream_descriptor=StreamDescriptor(name=stream_name, namespace=namespace),
                    stream_state=stream_state,
                ),
            ),
        )

    @classmethod
    def _extract_from_state_message(
        cls,
        state: Optional[List[AirbyteStateMessage]],
    ) -> Tuple[
        Optional[AirbyteStateBlob],
        MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]],
    ]:
        """
        Takes an incoming list of state messages or a global state message and extracts state attributes according to
        type which can then be assigned to the new state manager being instantiated
        :param state: The incoming state input
        :return: A tuple of shared state and per stream state assembled from the incoming state list
        """
        if state is None:
            return None, {}

        if cls._is_global_state(state):
            # We already validate that this is a global state message, not None:
            global_state = cast(AirbyteGlobalState, state[0].global_)
            # shared_state may itself be None; deepcopy(None) is None, which __init__ treats as "no shared state".
            shared_state: AirbyteStateBlob = cast(
                AirbyteStateBlob, copy.deepcopy(global_state.shared_state)
            )
            streams = {
                HashableStreamDescriptor(
                    name=per_stream_state.stream_descriptor.name,
                    namespace=per_stream_state.stream_descriptor.namespace,
                ): per_stream_state.stream_state
                for per_stream_state in global_state.stream_states  # type: ignore[union-attr] # stream_states is typed Optional
            }
            return shared_state, streams

        # `stream` is typed Optional (hence the union-attr ignores below), so test for None rather than for the
        # attribute's presence; a STREAM message without a payload is skipped instead of raising AttributeError.
        streams = {
            HashableStreamDescriptor(
                name=per_stream_state.stream.stream_descriptor.name,  # type: ignore[union-attr] # filtered to non-None below
                namespace=per_stream_state.stream.stream_descriptor.namespace,  # type: ignore[union-attr] # filtered to non-None below
            ): per_stream_state.stream.stream_state  # type: ignore[union-attr] # filtered to non-None below
            for per_stream_state in state
            if per_stream_state.type == AirbyteStateType.STREAM
            and per_stream_state.stream is not None
        }
        return None, streams

    @staticmethod
    def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
        """Return True when ``state`` is a list holding exactly one GLOBAL AirbyteStateMessage."""
        return (
            isinstance(state, list)
            and len(state) == 1
            and isinstance(state[0], AirbyteStateMessage)
            and state[0].type == AirbyteStateType.GLOBAL
        )

    @staticmethod
    def _is_per_stream_state(
        state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]],
    ) -> bool:
        """Return True when ``state`` is a list. Not called anywhere in this module."""
        return isinstance(state, list)
@dataclass(frozen=True)
class HashableStreamDescriptor:
    """
    Helper class that overrides the existing StreamDescriptor class that is auto generated from the Airbyte Protocol and
    freezes its fields so that it can be used as a hash key. This is only marked public because we use it outside for unit tests.
    """

    # Stream name; always required.
    name: str
    # Optional namespace; None when the stream has no namespace.
    namespace: Optional[str] = None
Helper class that overrides the existing StreamDescriptor class that is auto generated from the Airbyte Protocol and freezes its fields so that it can be used as a hash key. This is only marked public because we use it outside for unit tests.
class ConnectorStateManager:
    """
    ConnectorStateManager consolidates the various forms of a stream's incoming state message (STREAM / GLOBAL) under a common
    interface. It also provides methods to extract and update state.
    """

    def __init__(self, state: Optional[List[AirbyteStateMessage]] = None):
        """
        :param state: The incoming state messages, or None when there is no prior state
        :raises ValueError: If a GLOBAL state message carries a truthy shared_state
        """
        shared_state, per_stream_states = self._extract_from_state_message(state)

        # We explicitly throw an error if we receive a GLOBAL state message that contains a shared_state because API sources are
        # designed to checkpoint state independently of one another. API sources should never be emitting a state message where
        # shared_state is populated. Rather than define how to handle shared_state without a clear use case, we're opting to throw an
        # error instead and if/when we find one, we will then implement processing of the shared_state value.
        if shared_state:
            raise ValueError(
                "Received a GLOBAL AirbyteStateMessage that contains a shared_state. This library only ever generates per-STREAM "
                "STATE messages so this was not generated by this connector. This must be an orchestrator or platform error. GLOBAL "
                "state messages with shared_state will not be processed correctly. "
            )
        self.per_stream_states = per_stream_states

    def get_stream_state(
        self, stream_name: str, namespace: Optional[str]
    ) -> MutableMapping[str, Any]:
        """
        Retrieves the state of a given stream based on its descriptor (name + namespace).
        :param stream_name: Name of the stream being fetched
        :param namespace: Namespace of the stream being fetched
        :return: A deep copy of the per-stream state as a plain dict, or an empty dict if the stream has no state
        """
        stream_state: Optional[AirbyteStateBlob] = self.per_stream_states.get(
            HashableStreamDescriptor(name=stream_name, namespace=namespace)
        )
        if stream_state:
            # Deep copy so callers can mutate the returned mapping without altering the stored blob.
            return copy.deepcopy(dict(stream_state.__dict__))
        return {}

    def update_state_for_stream(
        self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]
    ) -> None:
        """
        Overwrites the state blob of a specific stream based on the provided stream name and optional namespace
        :param stream_name: The name of the stream whose state is being updated
        :param namespace: The namespace of the stream if it exists
        :param value: A stream state mapping that replaces any previously stored state for the stream
        """
        stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
        self.per_stream_states[stream_descriptor] = AirbyteStateBlob(value)

    def create_state_message(self, stream_name: str, namespace: Optional[str]) -> AirbyteMessage:
        """
        Generates an AirbyteMessage using the current per-stream state of a specified stream
        :param stream_name: The name of the stream for the message that is being created
        :param namespace: The namespace of the stream for the message that is being created
        :return: The Airbyte state message to be emitted by the connector during a sync; an empty
            AirbyteStateBlob is used when the stream has no stored state
        """
        hashable_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
        stream_state = self.per_stream_states.get(hashable_descriptor) or AirbyteStateBlob()

        return AirbyteMessage(
            type=MessageType.STATE,
            state=AirbyteStateMessage(
                type=AirbyteStateType.STREAM,
                stream=AirbyteStreamState(
                    stream_descriptor=StreamDescriptor(name=stream_name, namespace=namespace),
                    stream_state=stream_state,
                ),
            ),
        )

    @classmethod
    def _extract_from_state_message(
        cls,
        state: Optional[List[AirbyteStateMessage]],
    ) -> Tuple[
        Optional[AirbyteStateBlob],
        MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]],
    ]:
        """
        Takes an incoming list of state messages or a global state message and extracts state attributes according to
        type which can then be assigned to the new state manager being instantiated
        :param state: The incoming state input
        :return: A tuple of shared state and per stream state assembled from the incoming state list
        """
        if state is None:
            return None, {}

        if cls._is_global_state(state):
            # We already validate that this is a global state message, not None:
            global_state = cast(AirbyteGlobalState, state[0].global_)
            # shared_state may itself be None; deepcopy(None) is None, which __init__ treats as "no shared state".
            shared_state: AirbyteStateBlob = cast(
                AirbyteStateBlob, copy.deepcopy(global_state.shared_state)
            )
            streams = {
                HashableStreamDescriptor(
                    name=per_stream_state.stream_descriptor.name,
                    namespace=per_stream_state.stream_descriptor.namespace,
                ): per_stream_state.stream_state
                for per_stream_state in global_state.stream_states  # type: ignore[union-attr] # stream_states is typed Optional
            }
            return shared_state, streams

        # `stream` is typed Optional (hence the union-attr ignores below), so test for None rather than for the
        # attribute's presence; a STREAM message without a payload is skipped instead of raising AttributeError.
        streams = {
            HashableStreamDescriptor(
                name=per_stream_state.stream.stream_descriptor.name,  # type: ignore[union-attr] # filtered to non-None below
                namespace=per_stream_state.stream.stream_descriptor.namespace,  # type: ignore[union-attr] # filtered to non-None below
            ): per_stream_state.stream.stream_state  # type: ignore[union-attr] # filtered to non-None below
            for per_stream_state in state
            if per_stream_state.type == AirbyteStateType.STREAM
            and per_stream_state.stream is not None
        }
        return None, streams

    @staticmethod
    def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
        """Return True when ``state`` is a list holding exactly one GLOBAL AirbyteStateMessage."""
        return (
            isinstance(state, list)
            and len(state) == 1
            and isinstance(state[0], AirbyteStateMessage)
            and state[0].type == AirbyteStateType.GLOBAL
        )

    @staticmethod
    def _is_per_stream_state(
        state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]],
    ) -> bool:
        """Return True when ``state`` is a list. Not called from within this class."""
        return isinstance(state, list)
ConnectorStateManager consolidates the various forms of a stream's incoming state message (STREAM / GLOBAL) under a common interface. It also provides methods to extract and update state.
def __init__(self, state: Optional[List[AirbyteStateMessage]] = None):
    """
    Build the manager from the state messages handed to the connector.

    :param state: Incoming state messages, or None when there is no prior state
    :raises ValueError: When the extracted GLOBAL shared_state is truthy
    """
    shared, streams_by_descriptor = self._extract_from_state_message(state)

    # API sources checkpoint each stream independently and never emit shared_state. Rather than
    # guess how to process a populated shared_state without a real use case, it is rejected.
    if not shared:
        self.per_stream_states = streams_by_descriptor
        return

    raise ValueError(
        "Received a GLOBAL AirbyteStateMessage that contains a shared_state. This library only ever generates per-STREAM "
        "STATE messages so this was not generated by this connector. This must be an orchestrator or platform error. GLOBAL "
        "state messages with shared_state will not be processed correctly. "
    )
def get_stream_state(
    self, stream_name: str, namespace: Optional[str]
) -> MutableMapping[str, Any]:
    """
    Look up the stored state for the stream identified by name + namespace.
    :param stream_name: Name of the stream being fetched
    :param namespace: Namespace of the stream being fetched
    :return: A deep copy of the stream's state as a plain dict, or an empty dict when none is stored
    """
    key = HashableStreamDescriptor(name=stream_name, namespace=namespace)
    blob: Optional[AirbyteStateBlob] = self.per_stream_states.get(key)
    if not blob:
        return {}
    # Copy deeply so the caller can mutate the result without touching the stored blob.
    return copy.deepcopy(dict(blob.__dict__))
Retrieves the state of a given stream based on its descriptor (name + namespace).
Parameters
- stream_name: Name of the stream being fetched
- namespace: Namespace of the stream being fetched
Returns
The per-stream state for a stream
def update_state_for_stream(
    self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]
) -> None:
    """
    Replace whatever state is stored for the given stream with a new blob built from ``value``.
    :param stream_name: The name of the stream whose state is being updated
    :param namespace: The namespace of the stream if it exists
    :param value: A stream state mapping that is being updated for a stream
    """
    key = HashableStreamDescriptor(name=stream_name, namespace=namespace)
    new_blob = AirbyteStateBlob(value)
    self.per_stream_states[key] = new_blob
Overwrites the state blob of a specific stream based on the provided stream name and optional namespace
Parameters
- stream_name: The name of the stream whose state is being updated
- namespace: The namespace of the stream if it exists
- value: A stream state mapping that is being updated for a stream
def create_state_message(self, stream_name: str, namespace: Optional[str]) -> AirbyteMessage:
    """
    Wrap the current state of one stream in a STREAM-typed Airbyte STATE message.
    :param stream_name: The name of the stream for the message that is being created
    :param namespace: The namespace of the stream for the message that is being created
    :return: The Airbyte state message to emit; an empty AirbyteStateBlob is used when no state is stored
    """
    current = self.per_stream_states.get(
        HashableStreamDescriptor(name=stream_name, namespace=namespace)
    )
    if not current:
        current = AirbyteStateBlob()

    per_stream = AirbyteStreamState(
        stream_descriptor=StreamDescriptor(name=stream_name, namespace=namespace),
        stream_state=current,
    )
    return AirbyteMessage(
        type=MessageType.STATE,
        state=AirbyteStateMessage(type=AirbyteStateType.STREAM, stream=per_stream),
    )
Generates an AirbyteMessage using the current per-stream state of a specified stream
Parameters
- stream_name: The name of the stream for the message that is being created
- namespace: The namespace of the stream for the message that is being created
Returns
The Airbyte state message to be emitted by the connector during a sync