airbyte_cdk.utils

 1#
 2# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 3#
 4
 5from .is_cloud_environment import is_cloud_environment
 6from .print_buffer import PrintBuffer
 7from .schema_inferrer import SchemaInferrer
 8from .traced_exception import AirbyteTracedException
 9
# Public names re-exported by `airbyte_cdk.utils`.
__all__ = [
    "AirbyteTracedException",
    "SchemaInferrer",
    "is_cloud_environment",
    "PrintBuffer",
]
class AirbyteTracedException(builtins.Exception):
 26class AirbyteTracedException(Exception):
 27    """
 28    An exception that should be emitted as an AirbyteTraceMessage
 29    """
 30
 31    def __init__(
 32        self,
 33        internal_message: Optional[str] = None,
 34        message: Optional[str] = None,
 35        failure_type: FailureType = FailureType.system_error,
 36        exception: Optional[BaseException] = None,
 37        stream_descriptor: Optional[StreamDescriptor] = None,
 38    ):
 39        """
 40        :param internal_message: the internal error that caused the failure
 41        :param message: a user-friendly message that indicates the cause of the error
 42        :param failure_type: the type of error
 43        :param exception: the exception that caused the error, from which the stack trace should be retrieved
 44        :param stream_descriptor: describe the stream from which the exception comes from
 45        """
 46        self.internal_message = internal_message
 47        self.message = message
 48        self.failure_type = failure_type
 49        self._exception = exception
 50        self._stream_descriptor = stream_descriptor
 51        super().__init__(internal_message)
 52
 53    def as_airbyte_message(
 54        self, stream_descriptor: Optional[StreamDescriptor] = None
 55    ) -> AirbyteMessage:
 56        """
 57        Builds an AirbyteTraceMessage from the exception
 58
 59        :param stream_descriptor is deprecated, please use the stream_description in `__init__ or `from_exception`. If many
 60          stream_descriptors are defined, the one from `as_airbyte_message` will be discarded.
 61        """
 62        now_millis = time.time_ns() // 1_000_000
 63
 64        trace_exc = self._exception or self
 65        stack_trace_str = "".join(traceback.TracebackException.from_exception(trace_exc).format())
 66
 67        trace_message = AirbyteTraceMessage(
 68            type=TraceType.ERROR,
 69            emitted_at=now_millis,
 70            error=AirbyteErrorTraceMessage(
 71                message=self.message
 72                or "Something went wrong in the connector. See the logs for more details.",
 73                internal_message=self.internal_message,
 74                failure_type=self.failure_type,
 75                stack_trace=stack_trace_str,
 76                stream_descriptor=self._stream_descriptor
 77                if self._stream_descriptor is not None
 78                else stream_descriptor,
 79            ),
 80        )
 81
 82        return AirbyteMessage(type=MessageType.TRACE, trace=trace_message)
 83
 84    def as_connection_status_message(self) -> Optional[AirbyteMessage]:
 85        if self.failure_type == FailureType.config_error:
 86            return AirbyteMessage(
 87                type=MessageType.CONNECTION_STATUS,
 88                connectionStatus=AirbyteConnectionStatus(
 89                    status=Status.FAILED, message=self.message
 90                ),
 91            )
 92        return None
 93
 94    def emit_message(self) -> None:
 95        """
 96        Prints the exception as an AirbyteTraceMessage.
 97        Note that this will be called automatically on uncaught exceptions when using the airbyte_cdk entrypoint.
 98        """
 99        message = orjson.dumps(AirbyteMessageSerializer.dump(self.as_airbyte_message())).decode()
100        filtered_message = filter_secrets(message)
101        print(filtered_message)
102
103    @classmethod
104    def from_exception(
105        cls,
106        exc: BaseException,
107        stream_descriptor: Optional[StreamDescriptor] = None,
108        *args: Any,
109        **kwargs: Any,
110    ) -> "AirbyteTracedException":
111        """
112        Helper to create an AirbyteTracedException from an existing exception
113        :param exc: the exception that caused the error
114        :param stream_descriptor: describe the stream from which the exception comes from
115        """
116        return cls(
117            internal_message=str(exc),
118            exception=exc,
119            stream_descriptor=stream_descriptor,
120            *args,
121            **kwargs,
122        )  # type: ignore  # ignoring because of args and kwargs
123
124    def as_sanitized_airbyte_message(
125        self, stream_descriptor: Optional[StreamDescriptor] = None
126    ) -> AirbyteMessage:
127        """
128        Builds an AirbyteTraceMessage from the exception and sanitizes any secrets from the message body
129
130        :param stream_descriptor is deprecated, please use the stream_description in `__init__ or `from_exception`. If many
131          stream_descriptors are defined, the one from `as_sanitized_airbyte_message` will be discarded.
132        """
133        error_message = self.as_airbyte_message(stream_descriptor=stream_descriptor)
134        if error_message.trace.error.message:  # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage
135            error_message.trace.error.message = filter_secrets(  # type: ignore[union-attr]
136                error_message.trace.error.message,  # type: ignore[union-attr]
137            )
138        if error_message.trace.error.internal_message:  # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage
139            error_message.trace.error.internal_message = filter_secrets(  # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage
140                error_message.trace.error.internal_message  # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage
141            )
142        if error_message.trace.error.stack_trace:  # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage
143            error_message.trace.error.stack_trace = filter_secrets(  # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage
144                error_message.trace.error.stack_trace  # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage
145            )
146        return error_message

An exception that should be emitted as an AirbyteTraceMessage

AirbyteTracedException( internal_message: Optional[str] = None, message: Optional[str] = None, failure_type: airbyte_protocol_dataclasses.models.airbyte_protocol.FailureType = <FailureType.system_error: 'system_error'>, exception: Optional[BaseException] = None, stream_descriptor: Optional[airbyte_protocol_dataclasses.models.airbyte_protocol.StreamDescriptor] = None)
def __init__(
    self,
    internal_message: Optional[str] = None,
    message: Optional[str] = None,
    failure_type: FailureType = FailureType.system_error,
    exception: Optional[BaseException] = None,
    stream_descriptor: Optional[StreamDescriptor] = None,
):
    """
    :param internal_message: technical description of the failure; also passed to
        `Exception.__init__`, so it becomes the exception's first argument
    :param message: user-friendly explanation of the failure; a generic text is emitted when omitted
    :param failure_type: category of the failure (`FailureType.system_error` by default)
    :param exception: underlying exception whose traceback is reported; when omitted, the
        traceback of this exception is reported instead
    :param stream_descriptor: the stream the failure relates to; takes precedence over the
        deprecated argument of `as_airbyte_message` / `as_sanitized_airbyte_message`
    """
    # Details only consumed when building the trace message.
    self._exception = exception
    self._stream_descriptor = stream_descriptor
    self.internal_message = internal_message
    self.message = message
    self.failure_type = failure_type
    super().__init__(internal_message)
Parameters
  • internal_message: the internal error that caused the failure
  • message: a user-friendly message that indicates the cause of the error
  • failure_type: the type of error
  • exception: the exception that caused the error, from which the stack trace should be retrieved
  • stream_descriptor: describe the stream from which the exception comes from
internal_message
message
failure_type
def as_airbyte_message( self, stream_descriptor: Optional[airbyte_protocol_dataclasses.models.airbyte_protocol.StreamDescriptor] = None) -> airbyte_cdk.AirbyteMessage:
def as_airbyte_message(
    self, stream_descriptor: Optional[StreamDescriptor] = None
) -> AirbyteMessage:
    """
    Build a TRACE AirbyteMessage wrapping an ERROR AirbyteTraceMessage for this exception.

    :param stream_descriptor: deprecated; pass the stream descriptor to `__init__` or
        `from_exception` instead. If a descriptor was given there too, this one is discarded.
    :return: the TRACE AirbyteMessage describing this failure
    """
    emitted_at = time.time_ns() // 1_000_000  # epoch time in milliseconds

    # Prefer the traceback of the wrapped exception, fall back to our own.
    source_exception = self._exception or self
    formatted_frames = traceback.TracebackException.from_exception(source_exception).format()
    stack_trace = "".join(formatted_frames)

    descriptor = (
        stream_descriptor if self._stream_descriptor is None else self._stream_descriptor
    )
    user_message = (
        self.message
        or "Something went wrong in the connector. See the logs for more details."
    )

    error_details = AirbyteErrorTraceMessage(
        message=user_message,
        internal_message=self.internal_message,
        failure_type=self.failure_type,
        stack_trace=stack_trace,
        stream_descriptor=descriptor,
    )
    trace = AirbyteTraceMessage(
        type=TraceType.ERROR,
        emitted_at=emitted_at,
        error=error_details,
    )
    return AirbyteMessage(type=MessageType.TRACE, trace=trace)

Builds an AirbyteTraceMessage from the exception

:param stream_descriptor: deprecated; please pass the stream descriptor to `__init__` or `from_exception` instead. If a stream descriptor is provided in both places, the one passed to `as_airbyte_message` is discarded.

def as_connection_status_message(self) -> Optional[airbyte_cdk.AirbyteMessage]:
def as_connection_status_message(self) -> Optional[AirbyteMessage]:
    """
    Return a FAILED CONNECTION_STATUS message carrying `self.message` when this is a
    config error, otherwise None.
    """
    if self.failure_type != FailureType.config_error:
        return None
    status = AirbyteConnectionStatus(status=Status.FAILED, message=self.message)
    return AirbyteMessage(type=MessageType.CONNECTION_STATUS, connectionStatus=status)
def emit_message(self) -> None:
 94    def emit_message(self) -> None:
 95        """
 96        Prints the exception as an AirbyteTraceMessage.
 97        Note that this will be called automatically on uncaught exceptions when using the airbyte_cdk entrypoint.
 98        """
 99        message = orjson.dumps(AirbyteMessageSerializer.dump(self.as_airbyte_message())).decode()
100        filtered_message = filter_secrets(message)
101        print(filtered_message)

Prints the exception as an AirbyteTraceMessage. Note that this will be called automatically on uncaught exceptions when using the airbyte_cdk entrypoint.

@classmethod
def from_exception( cls, exc: BaseException, stream_descriptor: Optional[airbyte_protocol_dataclasses.models.airbyte_protocol.StreamDescriptor] = None, *args: Any, **kwargs: Any) -> AirbyteTracedException:
@classmethod
def from_exception(
    cls,
    exc: BaseException,
    stream_descriptor: Optional[StreamDescriptor] = None,
    *args: Any,
    **kwargs: Any,
) -> "AirbyteTracedException":
    """
    Helper to create an AirbyteTracedException from an existing exception.

    :param exc: the exception that caused the error; `str(exc)` becomes the internal message
        and its traceback is the one reported
    :param stream_descriptor: the stream from which the exception comes
    :param args: extra positional arguments forwarded to the constructor of `cls`; they bind
        before the keyword arguments below, so with the base class a positional argument
        collides with `internal_message` and raises TypeError
    :param kwargs: extra keyword arguments forwarded to the constructor of `cls`
    """
    # Positional arguments are always bound before keyword arguments, wherever they appear.
    return cls(
        *args,
        internal_message=str(exc),
        exception=exc,
        stream_descriptor=stream_descriptor,
        **kwargs,
    )  # type: ignore  # ignoring because of args and kwargs

Helper to create an AirbyteTracedException from an existing exception

Parameters
  • exc: the exception that caused the error
  • stream_descriptor: describe the stream from which the exception comes from
def as_sanitized_airbyte_message( self, stream_descriptor: Optional[airbyte_protocol_dataclasses.models.airbyte_protocol.StreamDescriptor] = None) -> airbyte_cdk.AirbyteMessage:
def as_sanitized_airbyte_message(
    self, stream_descriptor: Optional[StreamDescriptor] = None
) -> AirbyteMessage:
    """
    Same as `as_airbyte_message`, but with secrets filtered out of the user message,
    the internal message and the stack trace.

    :param stream_descriptor: deprecated; pass the stream descriptor to `__init__` or
        `from_exception` instead. If a descriptor was given there too, this one is discarded.
    :return: the sanitized TRACE AirbyteMessage
    """
    error_message = self.as_airbyte_message(stream_descriptor=stream_descriptor)
    # AirbyteMessage with MessageType.TRACE has AirbyteTraceMessage
    trace_error = error_message.trace.error  # type: ignore[union-attr]
    for field_name in ("message", "internal_message", "stack_trace"):
        field_value = getattr(trace_error, field_name)
        if field_value:  # empty / None values are left as they are
            setattr(trace_error, field_name, filter_secrets(field_value))
    return error_message

Builds an AirbyteTraceMessage from the exception and sanitizes any secrets from the message body

:param stream_descriptor: deprecated; please pass the stream descriptor to `__init__` or `from_exception` instead. If a stream descriptor is provided in both places, the one passed to `as_sanitized_airbyte_message` is discarded.

class SchemaInferrer:
 82class SchemaInferrer:
 83    """
 84    This class is used to infer a JSON schema which fits all the records passed into it
 85    throughout its lifecycle via the accumulate method.
 86
 87    Instances of this class are stateful, meaning they build their inferred schemas
 88    from every record passed into the accumulate method.
 89
 90    """
 91
 92    stream_to_builder: Dict[str, SchemaBuilder]
 93
 94    def __init__(
 95        self, pk: Optional[List[List[str]]] = None, cursor_field: Optional[List[List[str]]] = None
 96    ) -> None:
 97        self.stream_to_builder = defaultdict(NoRequiredSchemaBuilder)
 98        self._pk = [] if pk is None else pk
 99        self._cursor_field = [] if cursor_field is None else cursor_field
100
101    def accumulate(self, record: AirbyteRecordMessage) -> None:
102        """Uses the input record to add to the inferred schemas maintained by this object"""
103        self.stream_to_builder[record.stream].add_object(record.data)
104
105    def _null_type_in_any_of(self, node: InferredSchema) -> bool:
106        if _ANY_OF in node:
107            return {_TYPE: _NULL_TYPE} in node[_ANY_OF]
108        else:
109            return False
110
111    def _remove_type_from_any_of(self, node: InferredSchema) -> None:
112        if _ANY_OF in node:
113            node.pop(_TYPE, None)
114
115    def _clean_any_of(self, node: InferredSchema) -> None:
116        if len(node[_ANY_OF]) == 2 and self._null_type_in_any_of(node):
117            real_type = (
118                node[_ANY_OF][1] if node[_ANY_OF][0][_TYPE] == _NULL_TYPE else node[_ANY_OF][0]
119            )
120            node.update(real_type)
121            node[_TYPE] = [node[_TYPE], _NULL_TYPE]
122            node.pop(_ANY_OF)
123        # populate `type` for `anyOf` if it's not present to pass all other checks
124        elif len(node[_ANY_OF]) == 2 and not self._null_type_in_any_of(node):
125            node[_TYPE] = [_NULL_TYPE]
126
127    def _clean_properties(self, node: InferredSchema) -> None:
128        for key, value in list(node[_PROPERTIES].items()):
129            if isinstance(value, dict) and value.get(_TYPE, None) == _NULL_TYPE:
130                node[_PROPERTIES].pop(key)
131            else:
132                self._clean(value)
133
134    def _ensure_null_type_on_top(self, node: InferredSchema) -> None:
135        if isinstance(node[_TYPE], list):
136            if _NULL_TYPE in node[_TYPE]:
137                # we want to make sure null is always at the end as it makes schemas more readable
138                node[_TYPE].remove(_NULL_TYPE)
139            node[_TYPE].append(_NULL_TYPE)
140        else:
141            node[_TYPE] = [node[_TYPE], _NULL_TYPE]
142
143    def _clean(self, node: InferredSchema) -> InferredSchema:
144        """
145        Recursively cleans up a produced schema:
146        - remove anyOf if one of them is just a null value
147        - remove properties of type "null"
148        """
149
150        if isinstance(node, dict):
151            if _ANY_OF in node:
152                self._clean_any_of(node)
153
154            if _PROPERTIES in node and isinstance(node[_PROPERTIES], dict):
155                self._clean_properties(node)
156
157            if _ITEMS in node:
158                self._clean(node[_ITEMS])
159
160            # this check needs to follow the "anyOf" cleaning as it might populate `type`
161            self._ensure_null_type_on_top(node)
162
163        # remove added `type: ["null"]` for `anyOf` nested node
164        self._remove_type_from_any_of(node)
165
166        return node
167
168    def _add_required_properties(self, node: InferredSchema) -> InferredSchema:
169        """
170        This method takes properties that should be marked as required (self._pk and self._cursor_field) and travel the schema to mark every
171        node as required.
172        """
173        # Removing nullable for the root as when we call `_clean`, we make everything nullable
174        node[_TYPE] = _OBJECT_TYPE
175
176        exceptions = []
177        for field in [x for x in [self._pk, self._cursor_field] if x]:
178            try:
179                self._add_fields_as_required(node, field)
180            except SchemaValidationException as exception:
181                exceptions.append(exception)
182
183        if exceptions:
184            raise SchemaValidationException.merge_exceptions(exceptions)
185
186        return node
187
188    def _add_fields_as_required(self, node: InferredSchema, composite_key: List[List[str]]) -> None:
189        """
190        Take a list of nested keys (this list represents a composite key) and travel the schema to mark every node as required.
191        """
192        errors: List[Exception] = []
193
194        for path in composite_key:
195            try:
196                self._add_field_as_required(node, path)
197            except ValueError as exception:
198                errors.append(exception)
199
200        if errors:
201            raise SchemaValidationException(node, errors)
202
203    def _add_field_as_required(
204        self, node: InferredSchema, path: List[str], traveled_path: Optional[List[str]] = None
205    ) -> None:
206        """
207        Take a nested key and travel the schema to mark every node as required.
208        """
209        self._remove_null_from_type(node)
210        if self._is_leaf(path):
211            return
212
213        if not traveled_path:
214            traveled_path = []
215
216        if _PROPERTIES not in node:
217            # This validation is only relevant when `traveled_path` is empty
218            raise ValueError(
219                f"Path {traveled_path} does not refer to an object but is `{node}` and hence {path} can't be marked as required."
220            )
221
222        next_node = path[0]
223        if next_node not in node[_PROPERTIES]:
224            raise ValueError(
225                f"Path {traveled_path} does not have field `{next_node}` in the schema and hence can't be marked as required."
226            )
227
228        if _TYPE not in node:
229            # We do not expect this case to happen but we added a specific error message just in case
230            raise ValueError(
231                f"Unknown schema error: {traveled_path} is expected to have a type but did not. Schema inferrence is probably broken"
232            )
233
234        if node[_TYPE] not in [
235            _OBJECT_TYPE,
236            [_NULL_TYPE, _OBJECT_TYPE],
237            [_OBJECT_TYPE, _NULL_TYPE],
238        ]:
239            raise ValueError(
240                f"Path {traveled_path} is expected to be an object but was of type `{node['properties'][next_node]['type']}`"
241            )
242
243        if _REQUIRED not in node or not node[_REQUIRED]:
244            node[_REQUIRED] = [next_node]
245        elif next_node not in node[_REQUIRED]:
246            node[_REQUIRED].append(next_node)
247
248        traveled_path.append(next_node)
249        self._add_field_as_required(node[_PROPERTIES][next_node], path[1:], traveled_path)
250
251    def _is_leaf(self, path: List[str]) -> bool:
252        return len(path) == 0
253
254    def _remove_null_from_type(self, node: InferredSchema) -> None:
255        if isinstance(node[_TYPE], list):
256            if _NULL_TYPE in node[_TYPE]:
257                node[_TYPE].remove(_NULL_TYPE)
258            if len(node[_TYPE]) == 1:
259                node[_TYPE] = node[_TYPE][0]
260
261    def get_stream_schema(self, stream_name: str) -> Optional[InferredSchema]:
262        """
263        Returns the inferred JSON schema for the specified stream. Might be `None` if there were no records for the given stream name.
264        """
265        return (
266            self._add_required_properties(
267                self._clean(self.stream_to_builder[stream_name].to_schema())
268            )
269            if stream_name in self.stream_to_builder
270            else None
271        )

This class is used to infer a JSON schema which fits all the records passed into it throughout its lifecycle via the accumulate method.

Instances of this class are stateful, meaning they build their inferred schemas from every record passed into the accumulate method.

SchemaInferrer( pk: Optional[List[List[str]]] = None, cursor_field: Optional[List[List[str]]] = None)
def __init__(
    self, pk: Optional[List[List[str]]] = None, cursor_field: Optional[List[List[str]]] = None
) -> None:
    """
    :param pk: primary key as a list of field paths (each path is a list of nested keys);
        every path is marked as required in the schemas returned by `get_stream_schema`
    :param cursor_field: cursor field as a list of field paths, handled the same way as `pk`
    """
    # Builders are created lazily, on first access for a given stream.
    self.stream_to_builder = defaultdict(NoRequiredSchemaBuilder)
    self._pk = pk if pk is not None else []
    self._cursor_field = cursor_field if cursor_field is not None else []
stream_to_builder: Dict[str, genson.schema.builder.SchemaBuilder]
def accumulate( self, record: airbyte_protocol_dataclasses.models.airbyte_protocol.AirbyteRecordMessage) -> None:
def accumulate(self, record: AirbyteRecordMessage) -> None:
    """Feed `record.data` into the schema builder of the record's stream."""
    builder = self.stream_to_builder[record.stream]
    builder.add_object(record.data)

Uses the input record to add to the inferred schemas maintained by this object

def get_stream_schema(self, stream_name: str) -> Optional[Dict[str, Any]]:
def get_stream_schema(self, stream_name: str) -> Optional[InferredSchema]:
    """
    Returns the inferred JSON schema for the specified stream. Might be `None` if there were no records for the given stream name.

    The builder's raw schema is cleaned, then primary key and cursor paths are marked as required.
    """
    # Test membership first so the defaultdict does not create a builder for an unknown stream.
    if stream_name not in self.stream_to_builder:
        return None
    raw_schema = self.stream_to_builder[stream_name].to_schema()
    return self._add_required_properties(self._clean(raw_schema))

Returns the inferred JSON schema for the specified stream. Might be None if there were no records for the given stream name.

def is_cloud_environment() -> bool:
def is_cloud_environment() -> bool:
    """
    Report whether the connector runs in a cloud environment.

    Reads the DEPLOYMENT_MODE environment variable (set by the platform; treated as empty when
    unset), case-folds it and compares it with CLOUD_DEPLOYMENT_MODE. Callers can use the result
    to decide whether stricter security measures should be applied.

    :return: True in a cloud environment, False otherwise
    """
    # Only the environment value is case-folded; CLOUD_DEPLOYMENT_MODE is presumably lower-case.
    normalized_mode = os.environ.get("DEPLOYMENT_MODE", "").casefold()
    return normalized_mode == CLOUD_DEPLOYMENT_MODE

Returns True if the connector is running in a cloud environment, False otherwise.

The function checks the value of the DEPLOYMENT_MODE environment variable which is set by the platform. This function can be used to determine whether stricter security measures should be applied.

class PrintBuffer:
class PrintBuffer:
    """
    A class to buffer print statements and flush them at a specified interval.

    The PrintBuffer class is designed to capture and buffer output that would
    normally be printed to the standard output (stdout). This can be useful for
    scenarios where you want to minimize the number of I/O operations by grouping
    multiple print statements together and flushing them as a single operation.

    Attributes:
        buffer (StringIO): A buffer to store the messages before flushing.
        flush_interval (float): The time interval (in seconds) after which the buffer is flushed.
        last_flush_time (float): The last time (time.monotonic) an interval-triggered flush happened.
        lock (RLock): A reentrant lock to ensure thread-safe operations.

    Methods:
        write(message: str) -> None:
            Writes a message to the buffer and flushes if the interval has passed.

        flush() -> None:
            Flushes the buffer content to the standard output.

        __enter__() -> "PrintBuffer":
            Enters the runtime context related to this object, redirecting stdout and stderr
            (unless running under pytest).

        __exit__(exc_type, exc_val, exc_tb) -> None:
            Exits the runtime context: flushes the buffer, then restores the original stdout
            and stderr (restoration happens even if the final flush raises).
    """

    def __init__(self, flush_interval: float = 0.1):
        """
        :param flush_interval: minimum number of seconds between two flushes triggered by `write`
        """
        self.lock = RLock()
        self.flush_interval = flush_interval
        self.buffer = StringIO()
        self.last_flush_time = time.monotonic()

    def write(self, message: str) -> None:
        """
        Append `message` to the buffer, flushing when at least `flush_interval` seconds
        have elapsed since `last_flush_time`.
        """
        with self.lock:
            self.buffer.write(message)
            now = time.monotonic()
            if now - self.last_flush_time < self.flush_interval:
                return
            self.flush()
            self.last_flush_time = now

    def flush(self) -> None:
        """Write everything buffered so far to the real stdout and start a fresh, empty buffer."""
        with self.lock:
            pending_output = self.buffer.getvalue()
            # sys.__stdout__ is the interpreter's original stdout, bypassing our own redirection.
            sys.__stdout__.write(pending_output)  # type: ignore[union-attr]
            self.buffer = StringIO()

    def __enter__(self) -> "PrintBuffer":
        """Remember the current stdout/stderr and redirect both to this buffer."""
        self.old_stdout, self.old_stderr = sys.stdout, sys.stderr
        # Used to disable buffering during the pytest session, because it is not compatible with capsys
        running_under_pytest = "pytest" in str(type(sys.stdout)).lower()
        if not running_under_pytest:
            sys.stdout = self
            sys.stderr = self
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Flush pending output, then restore the stdout/stderr saved by `__enter__`."""
        try:
            self.flush()
        finally:
            # Restore even if the final flush raised; otherwise all later output would keep
            # going into a buffer that nobody flushes any more.
            sys.stdout, sys.stderr = self.old_stdout, self.old_stderr

A class to buffer print statements and flush them at a specified interval.

The PrintBuffer class is designed to capture and buffer output that would normally be printed to the standard output (stdout). This can be useful for scenarios where you want to minimize the number of I/O operations by grouping multiple print statements together and flushing them as a single operation.

Attributes:
  • buffer (StringIO): A buffer to store the messages before flushing.
  • flush_interval (float): The time interval (in seconds) after which the buffer is flushed.
  • last_flush_time (float): The last time the buffer was flushed.
  • lock (RLock): A reentrant lock to ensure thread-safe operations.
Methods:

write(message: str) -> None: Writes a message to the buffer and flushes if the interval has passed.

flush() -> None: Flushes the buffer content to the standard output.

__enter__() -> "PrintBuffer": Enters the runtime context related to this object, redirecting stdout and stderr.

__exit__(exc_type, exc_val, exc_tb) -> None: Exits the runtime context and restores the original stdout and stderr.

PrintBuffer(flush_interval: float = 0.1)
def __init__(self, flush_interval: float = 0.1):
    """
    :param flush_interval: minimum number of seconds between two flushes triggered by `write`
    """
    self.lock = RLock()
    self.flush_interval = flush_interval
    self.buffer = StringIO()
    self.last_flush_time = time.monotonic()
buffer
flush_interval
last_flush_time
lock
def write(self, message: str) -> None:
def write(self, message: str) -> None:
    """
    Append `message` to the buffer, flushing when at least `flush_interval` seconds
    have elapsed since `last_flush_time`.
    """
    with self.lock:
        self.buffer.write(message)
        now = time.monotonic()
        if now - self.last_flush_time < self.flush_interval:
            return
        self.flush()
        self.last_flush_time = now
def flush(self) -> None:
def flush(self) -> None:
    """Write everything buffered so far to the real stdout and start a fresh, empty buffer."""
    with self.lock:
        pending_output = self.buffer.getvalue()
        # sys.__stdout__ is the interpreter's original stdout, bypassing any redirection.
        sys.__stdout__.write(pending_output)  # type: ignore[union-attr]
        self.buffer = StringIO()