airbyte_cdk.sources.connector_state_manager

  1#
  2# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
  3#
  4
  5import copy
  6from dataclasses import dataclass
  7from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union, cast
  8
  9from airbyte_cdk.models import (
 10    AirbyteMessage,
 11    AirbyteStateBlob,
 12    AirbyteStateMessage,
 13    AirbyteStateType,
 14    AirbyteStreamState,
 15    StreamDescriptor,
 16)
 17from airbyte_cdk.models import Type as MessageType
 18from airbyte_cdk.models.airbyte_protocol import AirbyteGlobalState, AirbyteStateBlob
 19
 20
 21@dataclass(frozen=True)
 22class HashableStreamDescriptor:
 23    """
 24    Helper class that overrides the existing StreamDescriptor class that is auto generated from the Airbyte Protocol and
 25    freezes its fields so that it be used as a hash key. This is only marked public because we use it outside for unit tests.
 26    """
 27
 28    name: str
 29    namespace: Optional[str] = None
 30
 31
 32class ConnectorStateManager:
 33    """
 34    ConnectorStateManager consolidates the various forms of a stream's incoming state message (STREAM / GLOBAL) under a common
 35    interface. It also provides methods to extract and update state
 36    """
 37
 38    def __init__(self, state: Optional[List[AirbyteStateMessage]] = None):
 39        shared_state, per_stream_states = self._extract_from_state_message(state)
 40
 41        # We explicitly throw an error if we receive a GLOBAL state message that contains a shared_state because API sources are
 42        # designed to checkpoint state independently of one another. API sources should never be emitting a state message where
 43        # shared_state is populated. Rather than define how to handle shared_state without a clear use case, we're opting to throw an
 44        # error instead and if/when we find one, we will then implement processing of the shared_state value.
 45        if shared_state:
 46            raise ValueError(
 47                "Received a GLOBAL AirbyteStateMessage that contains a shared_state. This library only ever generates per-STREAM "
 48                "STATE messages so this was not generated by this connector. This must be an orchestrator or platform error. GLOBAL "
 49                "state messages with shared_state will not be processed correctly. "
 50            )
 51        self.per_stream_states = per_stream_states
 52
 53    def get_stream_state(
 54        self, stream_name: str, namespace: Optional[str]
 55    ) -> MutableMapping[str, Any]:
 56        """
 57        Retrieves the state of a given stream based on its descriptor (name + namespace).
 58        :param stream_name: Name of the stream being fetched
 59        :param namespace: Namespace of the stream being fetched
 60        :return: The per-stream state for a stream
 61        """
 62        stream_state: AirbyteStateBlob | None = self.per_stream_states.get(
 63            HashableStreamDescriptor(name=stream_name, namespace=namespace)
 64        )
 65        if stream_state:
 66            return copy.deepcopy({k: v for k, v in stream_state.__dict__.items()})
 67        return {}
 68
 69    def update_state_for_stream(
 70        self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]
 71    ) -> None:
 72        """
 73        Overwrites the state blob of a specific stream based on the provided stream name and optional namespace
 74        :param stream_name: The name of the stream whose state is being updated
 75        :param namespace: The namespace of the stream if it exists
 76        :param value: A stream state mapping that is being updated for a stream
 77        """
 78        stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
 79        self.per_stream_states[stream_descriptor] = AirbyteStateBlob(value)
 80
 81    def create_state_message(self, stream_name: str, namespace: Optional[str]) -> AirbyteMessage:
 82        """
 83        Generates an AirbyteMessage using the current per-stream state of a specified stream
 84        :param stream_name: The name of the stream for the message that is being created
 85        :param namespace: The namespace of the stream for the message that is being created
 86        :return: The Airbyte state message to be emitted by the connector during a sync
 87        """
 88        hashable_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
 89        stream_state = self.per_stream_states.get(hashable_descriptor) or AirbyteStateBlob()
 90
 91        return AirbyteMessage(
 92            type=MessageType.STATE,
 93            state=AirbyteStateMessage(
 94                type=AirbyteStateType.STREAM,
 95                stream=AirbyteStreamState(
 96                    stream_descriptor=StreamDescriptor(name=stream_name, namespace=namespace),
 97                    stream_state=stream_state,
 98                ),
 99            ),
100        )
101
102    @classmethod
103    def _extract_from_state_message(
104        cls,
105        state: Optional[List[AirbyteStateMessage]],
106    ) -> Tuple[
107        Optional[AirbyteStateBlob],
108        MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]],
109    ]:
110        """
111        Takes an incoming list of state messages or a global state message and extracts state attributes according to
112        type which can then be assigned to the new state manager being instantiated
113        :param state: The incoming state input
114        :return: A tuple of shared state and per stream state assembled from the incoming state list
115        """
116        if state is None:
117            return None, {}
118
119        is_global = cls._is_global_state(state)
120
121        if is_global:
122            # We already validate that this is a global state message, not None:
123            global_state = cast(AirbyteGlobalState, state[0].global_)
124            # global_state has shared_state, also not None:
125            shared_state: AirbyteStateBlob = cast(
126                AirbyteStateBlob, copy.deepcopy(global_state.shared_state, {})
127            )
128            streams = {
129                HashableStreamDescriptor(
130                    name=per_stream_state.stream_descriptor.name,
131                    namespace=per_stream_state.stream_descriptor.namespace,
132                ): per_stream_state.stream_state
133                for per_stream_state in global_state.stream_states  # type: ignore[union-attr] # global_state has shared_state
134            }
135            return shared_state, streams
136        else:
137            streams = {
138                HashableStreamDescriptor(
139                    name=per_stream_state.stream.stream_descriptor.name,  # type: ignore[union-attr] # stream has stream_descriptor
140                    namespace=per_stream_state.stream.stream_descriptor.namespace,  # type: ignore[union-attr] # stream has stream_descriptor
141                ): per_stream_state.stream.stream_state  # type: ignore[union-attr] # stream has stream_state
142                for per_stream_state in state
143                if per_stream_state.type == AirbyteStateType.STREAM
144                and hasattr(per_stream_state, "stream")  # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True
145            }
146            return None, streams
147
148    @staticmethod
149    def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
150        return (
151            isinstance(state, List)
152            and len(state) == 1
153            and isinstance(state[0], AirbyteStateMessage)
154            and state[0].type == AirbyteStateType.GLOBAL
155        )
156
157    @staticmethod
158    def _is_per_stream_state(
159        state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]],
160    ) -> bool:
161        return isinstance(state, List)
@dataclass(frozen=True)
class HashableStreamDescriptor:
@dataclass(eq=True, frozen=True)
class HashableStreamDescriptor:
    """
    Immutable, hashable stand-in for the StreamDescriptor class auto-generated from the Airbyte Protocol.

    Freezing the fields lets instances be used as dictionary keys (e.g. to index per-stream state).
    This class is public only because unit tests outside this module use it.
    """

    # Name of the stream.
    name: str
    # Namespace of the stream; None when the stream has no namespace.
    namespace: Optional[str] = None

Helper class that stands in for the existing StreamDescriptor class that is auto generated from the Airbyte Protocol and freezes its fields so that it can be used as a hash key. This is only marked public because we use it outside for unit tests.

HashableStreamDescriptor(name: str, namespace: Optional[str] = None)
name: str
namespace: Optional[str] = None
class ConnectorStateManager:
 33class ConnectorStateManager:
 34    """
 35    ConnectorStateManager consolidates the various forms of a stream's incoming state message (STREAM / GLOBAL) under a common
 36    interface. It also provides methods to extract and update state
 37    """
 38
 39    def __init__(self, state: Optional[List[AirbyteStateMessage]] = None):
 40        shared_state, per_stream_states = self._extract_from_state_message(state)
 41
 42        # We explicitly throw an error if we receive a GLOBAL state message that contains a shared_state because API sources are
 43        # designed to checkpoint state independently of one another. API sources should never be emitting a state message where
 44        # shared_state is populated. Rather than define how to handle shared_state without a clear use case, we're opting to throw an
 45        # error instead and if/when we find one, we will then implement processing of the shared_state value.
 46        if shared_state:
 47            raise ValueError(
 48                "Received a GLOBAL AirbyteStateMessage that contains a shared_state. This library only ever generates per-STREAM "
 49                "STATE messages so this was not generated by this connector. This must be an orchestrator or platform error. GLOBAL "
 50                "state messages with shared_state will not be processed correctly. "
 51            )
 52        self.per_stream_states = per_stream_states
 53
 54    def get_stream_state(
 55        self, stream_name: str, namespace: Optional[str]
 56    ) -> MutableMapping[str, Any]:
 57        """
 58        Retrieves the state of a given stream based on its descriptor (name + namespace).
 59        :param stream_name: Name of the stream being fetched
 60        :param namespace: Namespace of the stream being fetched
 61        :return: The per-stream state for a stream
 62        """
 63        stream_state: AirbyteStateBlob | None = self.per_stream_states.get(
 64            HashableStreamDescriptor(name=stream_name, namespace=namespace)
 65        )
 66        if stream_state:
 67            return copy.deepcopy({k: v for k, v in stream_state.__dict__.items()})
 68        return {}
 69
 70    def update_state_for_stream(
 71        self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]
 72    ) -> None:
 73        """
 74        Overwrites the state blob of a specific stream based on the provided stream name and optional namespace
 75        :param stream_name: The name of the stream whose state is being updated
 76        :param namespace: The namespace of the stream if it exists
 77        :param value: A stream state mapping that is being updated for a stream
 78        """
 79        stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
 80        self.per_stream_states[stream_descriptor] = AirbyteStateBlob(value)
 81
 82    def create_state_message(self, stream_name: str, namespace: Optional[str]) -> AirbyteMessage:
 83        """
 84        Generates an AirbyteMessage using the current per-stream state of a specified stream
 85        :param stream_name: The name of the stream for the message that is being created
 86        :param namespace: The namespace of the stream for the message that is being created
 87        :return: The Airbyte state message to be emitted by the connector during a sync
 88        """
 89        hashable_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
 90        stream_state = self.per_stream_states.get(hashable_descriptor) or AirbyteStateBlob()
 91
 92        return AirbyteMessage(
 93            type=MessageType.STATE,
 94            state=AirbyteStateMessage(
 95                type=AirbyteStateType.STREAM,
 96                stream=AirbyteStreamState(
 97                    stream_descriptor=StreamDescriptor(name=stream_name, namespace=namespace),
 98                    stream_state=stream_state,
 99                ),
100            ),
101        )
102
103    @classmethod
104    def _extract_from_state_message(
105        cls,
106        state: Optional[List[AirbyteStateMessage]],
107    ) -> Tuple[
108        Optional[AirbyteStateBlob],
109        MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]],
110    ]:
111        """
112        Takes an incoming list of state messages or a global state message and extracts state attributes according to
113        type which can then be assigned to the new state manager being instantiated
114        :param state: The incoming state input
115        :return: A tuple of shared state and per stream state assembled from the incoming state list
116        """
117        if state is None:
118            return None, {}
119
120        is_global = cls._is_global_state(state)
121
122        if is_global:
123            # We already validate that this is a global state message, not None:
124            global_state = cast(AirbyteGlobalState, state[0].global_)
125            # global_state has shared_state, also not None:
126            shared_state: AirbyteStateBlob = cast(
127                AirbyteStateBlob, copy.deepcopy(global_state.shared_state, {})
128            )
129            streams = {
130                HashableStreamDescriptor(
131                    name=per_stream_state.stream_descriptor.name,
132                    namespace=per_stream_state.stream_descriptor.namespace,
133                ): per_stream_state.stream_state
134                for per_stream_state in global_state.stream_states  # type: ignore[union-attr] # global_state has shared_state
135            }
136            return shared_state, streams
137        else:
138            streams = {
139                HashableStreamDescriptor(
140                    name=per_stream_state.stream.stream_descriptor.name,  # type: ignore[union-attr] # stream has stream_descriptor
141                    namespace=per_stream_state.stream.stream_descriptor.namespace,  # type: ignore[union-attr] # stream has stream_descriptor
142                ): per_stream_state.stream.stream_state  # type: ignore[union-attr] # stream has stream_state
143                for per_stream_state in state
144                if per_stream_state.type == AirbyteStateType.STREAM
145                and hasattr(per_stream_state, "stream")  # type: ignore # state is always a list of AirbyteStateMessage if is_per_stream is True
146            }
147            return None, streams
148
149    @staticmethod
150    def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
151        return (
152            isinstance(state, List)
153            and len(state) == 1
154            and isinstance(state[0], AirbyteStateMessage)
155            and state[0].type == AirbyteStateType.GLOBAL
156        )
157
158    @staticmethod
159    def _is_per_stream_state(
160        state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]],
161    ) -> bool:
162        return isinstance(state, List)

ConnectorStateManager consolidates the various forms of a stream's incoming state message (STREAM / GLOBAL) under a common interface. It also provides methods to extract and update state.

ConnectorStateManager( state: Optional[List[airbyte_cdk.models.airbyte_protocol.AirbyteStateMessage]] = None)
def __init__(self, state: Optional[List[AirbyteStateMessage]] = None):
    """
    Build the manager from the incoming state messages.

    :param state: The incoming state messages; None when there is no prior state
    :raises ValueError: If a GLOBAL state message carrying a shared_state is received
    """
    global_shared_state, states_by_stream = self._extract_from_state_message(state)

    # API sources checkpoint each stream independently, so they never emit a populated shared_state.
    # There is no defined way to process shared_state yet, so instead of guessing we fail loudly
    # and will add handling if/when a real use case appears.
    if not global_shared_state:
        self.per_stream_states = states_by_stream
        return

    raise ValueError(
        "Received a GLOBAL AirbyteStateMessage that contains a shared_state. This library only ever generates per-STREAM "
        "STATE messages so this was not generated by this connector. This must be an orchestrator or platform error. GLOBAL "
        "state messages with shared_state will not be processed correctly. "
    )
per_stream_states
def get_stream_state( self, stream_name: str, namespace: Optional[str]) -> MutableMapping[str, Any]:
def get_stream_state(
    self, stream_name: str, namespace: Optional[str]
) -> MutableMapping[str, Any]:
    """
    Retrieves the state of a given stream based on its descriptor (name + namespace).
    :param stream_name: Name of the stream being fetched
    :param namespace: Namespace of the stream being fetched
    :return: A deep copy of the stored per-stream state, or an empty dict if none is stored
    """
    stream_state: AirbyteStateBlob | None = self.per_stream_states.get(
        HashableStreamDescriptor(name=stream_name, namespace=namespace)
    )
    if stream_state:
        # Deep-copy so callers cannot mutate the stored state through the returned mapping.
        return copy.deepcopy(dict(stream_state.__dict__))
    return {}

Retrieves the state of a given stream based on its descriptor (name + namespace).

Parameters
  • stream_name: Name of the stream being fetched
  • namespace: Namespace of the stream being fetched
Returns

The per-stream state for a stream

def update_state_for_stream( self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]) -> None:
def update_state_for_stream(
    self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]
) -> None:
    """
    Replace (not merge) the stored state blob for one stream, keyed by its name and optional namespace.

    :param stream_name: Name of the stream whose state is replaced
    :param namespace: Namespace of the stream, or None when it has none
    :param value: The new state mapping for the stream
    """
    self.per_stream_states[
        HashableStreamDescriptor(name=stream_name, namespace=namespace)
    ] = AirbyteStateBlob(value)

Overwrites the state blob of a specific stream based on the provided stream name and optional namespace

Parameters
  • stream_name: The name of the stream whose state is being updated
  • namespace: The namespace of the stream if it exists
  • value: A stream state mapping that is being updated for a stream
def create_state_message( self, stream_name: str, namespace: Optional[str]) -> airbyte_cdk.AirbyteMessage:
 82    def create_state_message(self, stream_name: str, namespace: Optional[str]) -> AirbyteMessage:
 83        """
 84        Generates an AirbyteMessage using the current per-stream state of a specified stream
 85        :param stream_name: The name of the stream for the message that is being created
 86        :param namespace: The namespace of the stream for the message that is being created
 87        :return: The Airbyte state message to be emitted by the connector during a sync
 88        """
 89        hashable_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
 90        stream_state = self.per_stream_states.get(hashable_descriptor) or AirbyteStateBlob()
 91
 92        return AirbyteMessage(
 93            type=MessageType.STATE,
 94            state=AirbyteStateMessage(
 95                type=AirbyteStateType.STREAM,
 96                stream=AirbyteStreamState(
 97                    stream_descriptor=StreamDescriptor(name=stream_name, namespace=namespace),
 98                    stream_state=stream_state,
 99                ),
100            ),
101        )

Generates an AirbyteMessage using the current per-stream state of a specified stream

Parameters
  • stream_name: The name of the stream for the message that is being created
  • namespace: The namespace of the stream for the message that is being created
Returns

The Airbyte state message to be emitted by the connector during a sync