airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter

  1#
  2# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
  3#
  4
  5from abc import ABC, abstractmethod
  6from enum import Enum
  7from typing import TYPE_CHECKING, Any, Callable, List, MutableMapping, Optional, Tuple
  8
  9if TYPE_CHECKING:
 10    from airbyte_cdk.sources.streams.concurrent.cursor import CursorField
 11
 12
class ConcurrencyCompatibleStateType(Enum):
    """
    State types that are recognised as concurrency-compatible.

    Each member's value is the string stored under the "state_type" key of a state message.
    """

    # NOTE(review): presumably states whose slice boundaries are dates/timestamps - confirm against the converters.
    date_range = "date-range"
    # NOTE(review): presumably states whose slice boundaries are plain integers - confirm against the converters.
    integer = "integer"
 16
 17
 18class AbstractStreamStateConverter(ABC):
 19    START_KEY = "start"
 20    END_KEY = "end"
 21    MOST_RECENT_RECORD_KEY = "most_recent_cursor_value"
 22
 23    @abstractmethod
 24    def _from_state_message(self, value: Any) -> Any:
 25        pass
 26
 27    @abstractmethod
 28    def _to_state_message(self, value: Any) -> Any:
 29        pass
 30
 31    def __init__(self, is_sequential_state: bool = True):
 32        self._is_sequential_state = is_sequential_state
 33
 34    def convert_to_state_message(
 35        self, cursor_field: "CursorField", stream_state: MutableMapping[str, Any]
 36    ) -> MutableMapping[str, Any]:
 37        """
 38        Convert the state message from the concurrency-compatible format to the stream's original format.
 39
 40        e.g.
 41        { "created": "2021-01-18T21:18:20.000Z" }
 42        """
 43        if self.is_state_message_compatible(stream_state) and self._is_sequential_state:
 44            legacy_state = stream_state.get("legacy", {})
 45            latest_complete_time = self._get_latest_complete_time(stream_state.get("slices", []))
 46            if latest_complete_time is not None:
 47                legacy_state.update(
 48                    {cursor_field.cursor_field_key: self._to_state_message(latest_complete_time)}
 49                )
 50            return legacy_state or {}
 51        else:
 52            return self.serialize(stream_state, ConcurrencyCompatibleStateType.date_range)
 53
 54    def _get_latest_complete_time(self, slices: List[MutableMapping[str, Any]]) -> Any:
 55        """
 56        Get the latest time before which all records have been processed.
 57        """
 58        if not slices:
 59            raise RuntimeError(
 60                "Expected at least one slice but there were none. This is unexpected; please contact Support."
 61            )
 62        merged_intervals = self.merge_intervals(slices)
 63        first_interval = merged_intervals[0]
 64
 65        return first_interval.get("most_recent_cursor_value") or first_interval[self.START_KEY]
 66
 67    def deserialize(self, state: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
 68        """
 69        Perform any transformations needed for compatibility with the converter.
 70        """
 71        for stream_slice in state.get("slices", []):
 72            stream_slice[self.START_KEY] = self._from_state_message(stream_slice[self.START_KEY])
 73            stream_slice[self.END_KEY] = self._from_state_message(stream_slice[self.END_KEY])
 74            if self.MOST_RECENT_RECORD_KEY in stream_slice:
 75                stream_slice[self.MOST_RECENT_RECORD_KEY] = self._from_state_message(
 76                    stream_slice[self.MOST_RECENT_RECORD_KEY]
 77                )
 78        return state
 79
 80    def serialize(
 81        self, state: MutableMapping[str, Any], state_type: ConcurrencyCompatibleStateType
 82    ) -> MutableMapping[str, Any]:
 83        """
 84        Perform any transformations needed for compatibility with the converter.
 85        """
 86        serialized_slices = []
 87        for stream_slice in state.get("slices", []):
 88            serialized_slice = {
 89                self.START_KEY: self._to_state_message(stream_slice[self.START_KEY]),
 90                self.END_KEY: self._to_state_message(stream_slice[self.END_KEY]),
 91            }
 92            if stream_slice.get(self.MOST_RECENT_RECORD_KEY):
 93                serialized_slice[self.MOST_RECENT_RECORD_KEY] = self._to_state_message(
 94                    stream_slice[self.MOST_RECENT_RECORD_KEY]
 95                )
 96            serialized_slices.append(serialized_slice)
 97        return {"slices": serialized_slices, "state_type": state_type.value}
 98
 99    @staticmethod
100    def is_state_message_compatible(state: MutableMapping[str, Any]) -> bool:
101        return bool(state) and state.get("state_type") in [
102            t.value for t in ConcurrencyCompatibleStateType
103        ]
104
105    @abstractmethod
106    def convert_from_sequential_state(
107        self,
108        cursor_field: "CursorField",  # to deprecate as it is only needed for sequential state
109        stream_state: MutableMapping[str, Any],
110        start: Optional[Any],
111    ) -> Tuple[Any, MutableMapping[str, Any]]:
112        """
113        Convert the state message to the format required by the ConcurrentCursor.
114
115        e.g.
116        {
117            "state_type": ConcurrencyCompatibleStateType.date_range.value,
118            "metadata": { … },
119            "slices": [
120                {starts: 0, end: 1617030403, finished_processing: true}]
121        }
122        """
123        ...
124
125    @abstractmethod
126    def increment(self, value: Any) -> Any:
127        """
128        Increment a timestamp by a single unit.
129        """
130        ...
131
132    @abstractmethod
133    def output_format(self, value: Any) -> Any:
134        """
135        Convert the cursor value type to a JSON valid type.
136        """
137        ...
138
139    def merge_intervals(
140        self, intervals: List[MutableMapping[str, Any]]
141    ) -> List[MutableMapping[str, Any]]:
142        """
143        Compute and return a list of merged intervals.
144
145        Intervals may be merged if the start time of the second interval is 1 unit or less (as defined by the
146        `increment` method) than the end time of the first interval.
147        """
148        if not intervals:
149            return []
150
151        sorted_intervals = sorted(
152            intervals, key=lambda interval: (interval[self.START_KEY], interval[self.END_KEY])
153        )
154        merged_intervals = [sorted_intervals[0]]
155
156        for current_interval in sorted_intervals[1:]:
157            last_interval = merged_intervals[-1]
158            last_interval_end = last_interval[self.END_KEY]
159            current_interval_start = current_interval[self.START_KEY]
160
161            if self.increment(last_interval_end) >= current_interval_start:
162                last_interval[self.END_KEY] = max(last_interval_end, current_interval[self.END_KEY])
163                last_interval_cursor_value = last_interval.get("most_recent_cursor_value")
164                current_interval_cursor_value = current_interval.get("most_recent_cursor_value")
165
166                last_interval["most_recent_cursor_value"] = (
167                    max(current_interval_cursor_value, last_interval_cursor_value)
168                    if current_interval_cursor_value and last_interval_cursor_value
169                    else current_interval_cursor_value or last_interval_cursor_value
170                )
171            else:
172                # Add a new interval if no overlap
173                merged_intervals.append(current_interval)
174
175        return merged_intervals
176
177    @abstractmethod
178    def parse_value(self, value: Any) -> Any:
179        """
180        Parse the value of the cursor field into a comparable value.
181        """
182        ...
183
184    @property
185    @abstractmethod
186    def zero_value(self) -> Any: ...
class ConcurrencyCompatibleStateType(enum.Enum):
14class ConcurrencyCompatibleStateType(Enum):
15    date_range = "date-range"
16    integer = "integer"

An enumeration.

date_range = <ConcurrencyCompatibleStateType.date_range: 'date-range'>
class AbstractStreamStateConverter(abc.ABC):
 19class AbstractStreamStateConverter(ABC):
 20    START_KEY = "start"
 21    END_KEY = "end"
 22    MOST_RECENT_RECORD_KEY = "most_recent_cursor_value"
 23
 24    @abstractmethod
 25    def _from_state_message(self, value: Any) -> Any:
 26        pass
 27
 28    @abstractmethod
 29    def _to_state_message(self, value: Any) -> Any:
 30        pass
 31
 32    def __init__(self, is_sequential_state: bool = True):
 33        self._is_sequential_state = is_sequential_state
 34
 35    def convert_to_state_message(
 36        self, cursor_field: "CursorField", stream_state: MutableMapping[str, Any]
 37    ) -> MutableMapping[str, Any]:
 38        """
 39        Convert the state message from the concurrency-compatible format to the stream's original format.
 40
 41        e.g.
 42        { "created": "2021-01-18T21:18:20.000Z" }
 43        """
 44        if self.is_state_message_compatible(stream_state) and self._is_sequential_state:
 45            legacy_state = stream_state.get("legacy", {})
 46            latest_complete_time = self._get_latest_complete_time(stream_state.get("slices", []))
 47            if latest_complete_time is not None:
 48                legacy_state.update(
 49                    {cursor_field.cursor_field_key: self._to_state_message(latest_complete_time)}
 50                )
 51            return legacy_state or {}
 52        else:
 53            return self.serialize(stream_state, ConcurrencyCompatibleStateType.date_range)
 54
 55    def _get_latest_complete_time(self, slices: List[MutableMapping[str, Any]]) -> Any:
 56        """
 57        Get the latest time before which all records have been processed.
 58        """
 59        if not slices:
 60            raise RuntimeError(
 61                "Expected at least one slice but there were none. This is unexpected; please contact Support."
 62            )
 63        merged_intervals = self.merge_intervals(slices)
 64        first_interval = merged_intervals[0]
 65
 66        return first_interval.get("most_recent_cursor_value") or first_interval[self.START_KEY]
 67
 68    def deserialize(self, state: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
 69        """
 70        Perform any transformations needed for compatibility with the converter.
 71        """
 72        for stream_slice in state.get("slices", []):
 73            stream_slice[self.START_KEY] = self._from_state_message(stream_slice[self.START_KEY])
 74            stream_slice[self.END_KEY] = self._from_state_message(stream_slice[self.END_KEY])
 75            if self.MOST_RECENT_RECORD_KEY in stream_slice:
 76                stream_slice[self.MOST_RECENT_RECORD_KEY] = self._from_state_message(
 77                    stream_slice[self.MOST_RECENT_RECORD_KEY]
 78                )
 79        return state
 80
 81    def serialize(
 82        self, state: MutableMapping[str, Any], state_type: ConcurrencyCompatibleStateType
 83    ) -> MutableMapping[str, Any]:
 84        """
 85        Perform any transformations needed for compatibility with the converter.
 86        """
 87        serialized_slices = []
 88        for stream_slice in state.get("slices", []):
 89            serialized_slice = {
 90                self.START_KEY: self._to_state_message(stream_slice[self.START_KEY]),
 91                self.END_KEY: self._to_state_message(stream_slice[self.END_KEY]),
 92            }
 93            if stream_slice.get(self.MOST_RECENT_RECORD_KEY):
 94                serialized_slice[self.MOST_RECENT_RECORD_KEY] = self._to_state_message(
 95                    stream_slice[self.MOST_RECENT_RECORD_KEY]
 96                )
 97            serialized_slices.append(serialized_slice)
 98        return {"slices": serialized_slices, "state_type": state_type.value}
 99
100    @staticmethod
101    def is_state_message_compatible(state: MutableMapping[str, Any]) -> bool:
102        return bool(state) and state.get("state_type") in [
103            t.value for t in ConcurrencyCompatibleStateType
104        ]
105
106    @abstractmethod
107    def convert_from_sequential_state(
108        self,
109        cursor_field: "CursorField",  # to deprecate as it is only needed for sequential state
110        stream_state: MutableMapping[str, Any],
111        start: Optional[Any],
112    ) -> Tuple[Any, MutableMapping[str, Any]]:
113        """
114        Convert the state message to the format required by the ConcurrentCursor.
115
116        e.g.
117        {
118            "state_type": ConcurrencyCompatibleStateType.date_range.value,
119            "metadata": { … },
120            "slices": [
121                {starts: 0, end: 1617030403, finished_processing: true}]
122        }
123        """
124        ...
125
126    @abstractmethod
127    def increment(self, value: Any) -> Any:
128        """
129        Increment a timestamp by a single unit.
130        """
131        ...
132
133    @abstractmethod
134    def output_format(self, value: Any) -> Any:
135        """
136        Convert the cursor value type to a JSON valid type.
137        """
138        ...
139
140    def merge_intervals(
141        self, intervals: List[MutableMapping[str, Any]]
142    ) -> List[MutableMapping[str, Any]]:
143        """
144        Compute and return a list of merged intervals.
145
146        Intervals may be merged if the start time of the second interval is 1 unit or less (as defined by the
147        `increment` method) than the end time of the first interval.
148        """
149        if not intervals:
150            return []
151
152        sorted_intervals = sorted(
153            intervals, key=lambda interval: (interval[self.START_KEY], interval[self.END_KEY])
154        )
155        merged_intervals = [sorted_intervals[0]]
156
157        for current_interval in sorted_intervals[1:]:
158            last_interval = merged_intervals[-1]
159            last_interval_end = last_interval[self.END_KEY]
160            current_interval_start = current_interval[self.START_KEY]
161
162            if self.increment(last_interval_end) >= current_interval_start:
163                last_interval[self.END_KEY] = max(last_interval_end, current_interval[self.END_KEY])
164                last_interval_cursor_value = last_interval.get("most_recent_cursor_value")
165                current_interval_cursor_value = current_interval.get("most_recent_cursor_value")
166
167                last_interval["most_recent_cursor_value"] = (
168                    max(current_interval_cursor_value, last_interval_cursor_value)
169                    if current_interval_cursor_value and last_interval_cursor_value
170                    else current_interval_cursor_value or last_interval_cursor_value
171                )
172            else:
173                # Add a new interval if no overlap
174                merged_intervals.append(current_interval)
175
176        return merged_intervals
177
178    @abstractmethod
179    def parse_value(self, value: Any) -> Any:
180        """
181        Parse the value of the cursor field into a comparable value.
182        """
183        ...
184
185    @property
186    @abstractmethod
187    def zero_value(self) -> Any: ...

Helper class that provides a standard way to create an ABC using inheritance.

START_KEY = 'start'
END_KEY = 'end'
MOST_RECENT_RECORD_KEY = 'most_recent_cursor_value'
def convert_to_state_message( self, cursor_field: airbyte_cdk.CursorField, stream_state: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
35    def convert_to_state_message(
36        self, cursor_field: "CursorField", stream_state: MutableMapping[str, Any]
37    ) -> MutableMapping[str, Any]:
38        """
39        Convert the state message from the concurrency-compatible format to the stream's original format.
40
41        e.g.
42        { "created": "2021-01-18T21:18:20.000Z" }
43        """
44        if self.is_state_message_compatible(stream_state) and self._is_sequential_state:
45            legacy_state = stream_state.get("legacy", {})
46            latest_complete_time = self._get_latest_complete_time(stream_state.get("slices", []))
47            if latest_complete_time is not None:
48                legacy_state.update(
49                    {cursor_field.cursor_field_key: self._to_state_message(latest_complete_time)}
50                )
51            return legacy_state or {}
52        else:
53            return self.serialize(stream_state, ConcurrencyCompatibleStateType.date_range)

Convert the state message from the concurrency-compatible format to the stream's original format.

e.g. { "created": "2021-01-18T21:18:20.000Z" }

def deserialize(self, state: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
68    def deserialize(self, state: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
69        """
70        Perform any transformations needed for compatibility with the converter.
71        """
72        for stream_slice in state.get("slices", []):
73            stream_slice[self.START_KEY] = self._from_state_message(stream_slice[self.START_KEY])
74            stream_slice[self.END_KEY] = self._from_state_message(stream_slice[self.END_KEY])
75            if self.MOST_RECENT_RECORD_KEY in stream_slice:
76                stream_slice[self.MOST_RECENT_RECORD_KEY] = self._from_state_message(
77                    stream_slice[self.MOST_RECENT_RECORD_KEY]
78                )
79        return state

Convert, in place, the start, end and (when present) most recent cursor value of every slice in the state from their state message representation, and return the state.

def serialize( self, state: MutableMapping[str, Any], state_type: ConcurrencyCompatibleStateType) -> MutableMapping[str, Any]:
81    def serialize(
82        self, state: MutableMapping[str, Any], state_type: ConcurrencyCompatibleStateType
83    ) -> MutableMapping[str, Any]:
84        """
85        Perform any transformations needed for compatibility with the converter.
86        """
87        serialized_slices = []
88        for stream_slice in state.get("slices", []):
89            serialized_slice = {
90                self.START_KEY: self._to_state_message(stream_slice[self.START_KEY]),
91                self.END_KEY: self._to_state_message(stream_slice[self.END_KEY]),
92            }
93            if stream_slice.get(self.MOST_RECENT_RECORD_KEY):
94                serialized_slice[self.MOST_RECENT_RECORD_KEY] = self._to_state_message(
95                    stream_slice[self.MOST_RECENT_RECORD_KEY]
96                )
97            serialized_slices.append(serialized_slice)
98        return {"slices": serialized_slices, "state_type": state_type.value}

Build a new state mapping whose slices hold the state message representation of the given state's slices, tagged with the given state type.

@staticmethod
def is_state_message_compatible(state: MutableMapping[str, Any]) -> bool:
100    @staticmethod
101    def is_state_message_compatible(state: MutableMapping[str, Any]) -> bool:
102        return bool(state) and state.get("state_type") in [
103            t.value for t in ConcurrencyCompatibleStateType
104        ]
@abstractmethod
def convert_from_sequential_state( self, cursor_field: airbyte_cdk.CursorField, stream_state: MutableMapping[str, Any], start: Optional[Any]) -> Tuple[Any, MutableMapping[str, Any]]:
106    @abstractmethod
107    def convert_from_sequential_state(
108        self,
109        cursor_field: "CursorField",  # to deprecate as it is only needed for sequential state
110        stream_state: MutableMapping[str, Any],
111        start: Optional[Any],
112    ) -> Tuple[Any, MutableMapping[str, Any]]:
113        """
114        Convert the state message to the format required by the ConcurrentCursor.
115
116        e.g.
117        {
118            "state_type": ConcurrencyCompatibleStateType.date_range.value,
119            "metadata": { … },
120            "slices": [
121                {starts: 0, end: 1617030403, finished_processing: true}]
122        }
123        """
124        ...

Convert the state message to the format required by the ConcurrentCursor.

e.g. { "state_type": ConcurrencyCompatibleStateType.date_range.value, "metadata": { … }, "slices": [ {start: 0, end: 1617030403, finished_processing: true}] }

@abstractmethod
def increment(self, value: Any) -> Any:
126    @abstractmethod
127    def increment(self, value: Any) -> Any:
128        """
129        Increment a timestamp by a single unit.
130        """
131        ...

Increment a timestamp by a single unit.

@abstractmethod
def output_format(self, value: Any) -> Any:
133    @abstractmethod
134    def output_format(self, value: Any) -> Any:
135        """
136        Convert the cursor value type to a JSON valid type.
137        """
138        ...

Convert the cursor value type to a JSON valid type.

def merge_intervals( self, intervals: List[MutableMapping[str, Any]]) -> List[MutableMapping[str, Any]]:
140    def merge_intervals(
141        self, intervals: List[MutableMapping[str, Any]]
142    ) -> List[MutableMapping[str, Any]]:
143        """
144        Compute and return a list of merged intervals.
145
146        Intervals may be merged if the start time of the second interval is 1 unit or less (as defined by the
147        `increment` method) than the end time of the first interval.
148        """
149        if not intervals:
150            return []
151
152        sorted_intervals = sorted(
153            intervals, key=lambda interval: (interval[self.START_KEY], interval[self.END_KEY])
154        )
155        merged_intervals = [sorted_intervals[0]]
156
157        for current_interval in sorted_intervals[1:]:
158            last_interval = merged_intervals[-1]
159            last_interval_end = last_interval[self.END_KEY]
160            current_interval_start = current_interval[self.START_KEY]
161
162            if self.increment(last_interval_end) >= current_interval_start:
163                last_interval[self.END_KEY] = max(last_interval_end, current_interval[self.END_KEY])
164                last_interval_cursor_value = last_interval.get("most_recent_cursor_value")
165                current_interval_cursor_value = current_interval.get("most_recent_cursor_value")
166
167                last_interval["most_recent_cursor_value"] = (
168                    max(current_interval_cursor_value, last_interval_cursor_value)
169                    if current_interval_cursor_value and last_interval_cursor_value
170                    else current_interval_cursor_value or last_interval_cursor_value
171                )
172            else:
173                # Add a new interval if no overlap
174                merged_intervals.append(current_interval)
175
176        return merged_intervals

Compute and return a list of merged intervals.

Two intervals are merged when the second interval starts no later than one unit (as defined by the increment method) after the end of the first interval.

@abstractmethod
def parse_value(self, value: Any) -> Any:
178    @abstractmethod
179    def parse_value(self, value: Any) -> Any:
180        """
181        Parse the value of the cursor field into a comparable value.
182        """
183        ...

Parse the value of the cursor field into a comparable value.

zero_value: Any
185    @property
186    @abstractmethod
187    def zero_value(self) -> Any: ...