airbyte.caches.base
SQL Cache implementation.
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
"""SQL Cache implementation."""

from __future__ import annotations

from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, ClassVar, Literal, final

import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
from pydantic import Field, PrivateAttr
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy import text

from airbyte_protocol.models import ConfiguredAirbyteCatalog

from airbyte import constants
from airbyte._writers.base import AirbyteWriterInterface
from airbyte.caches._catalog_backend import CatalogBackendBase, SqlCatalogBackend
from airbyte.caches._state_backend import SqlStateBackend
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE, TEMP_FILE_CLEANUP
from airbyte.datasets._sql import CachedDataset
from airbyte.shared.catalog_providers import CatalogProvider
from airbyte.shared.sql_processor import SqlConfig
from airbyte.shared.state_writers import StdOutStateWriter


if TYPE_CHECKING:
    from collections.abc import Iterator

    from airbyte._message_iterators import AirbyteMessageIterator
    from airbyte.caches._state_backend_base import StateBackendBase
    from airbyte.progress import ProgressTracker
    from airbyte.shared.sql_processor import SqlProcessorBase
    from airbyte.shared.state_providers import StateProviderBase
    from airbyte.shared.state_writers import StateWriterBase
    from airbyte.sources.base import Source
    from airbyte.strategies import WriteStrategy


class CacheBase(SqlConfig, AirbyteWriterInterface):
    """Base configuration for a cache.

    Caches inherit from the matching `SqlConfig` class, which provides the SQL config settings
    and basic connectivity to the SQL database.

    The cache is responsible for managing the state of the data synced to the cache, including the
    stream catalog and stream state. The cache also provides the mechanism to read and write data
    to the SQL backend specified in the `SqlConfig` class.
    """

    cache_dir: Path = Field(default=Path(constants.DEFAULT_CACHE_ROOT))
    """The directory to store the cache in."""

    cleanup: bool = TEMP_FILE_CLEANUP
    """Whether to clean up the cache after use."""

    _name: str = PrivateAttr()

    _sql_processor_class: ClassVar[type[SqlProcessorBase]]
    _read_processor: SqlProcessorBase = PrivateAttr()

    _catalog_backend: CatalogBackendBase = PrivateAttr()
    _state_backend: StateBackendBase = PrivateAttr()

    paired_destination_name: ClassVar[str | None] = None
    paired_destination_config_class: ClassVar[type | None] = None

    @property
    def paired_destination_config(self) -> Any | dict[str, Any]:  # noqa: ANN401  # Allow Any return type
        """Return a dictionary of destination configuration values."""
        raise NotImplementedError(
            f"The type '{type(self).__name__}' does not define an equivalent destination "
            "configuration."
        )

    def __init__(self, **data: Any) -> None:  # noqa: ANN401
        """Initialize the cache and backends."""
        super().__init__(**data)

        # Create a temporary processor to do the work of ensuring the schema exists
        temp_processor = self._sql_processor_class(
            sql_config=self,
            catalog_provider=CatalogProvider(ConfiguredAirbyteCatalog(streams=[])),
            state_writer=StdOutStateWriter(),
            temp_dir=self.cache_dir,
            temp_file_cleanup=self.cleanup,
        )
        temp_processor._ensure_schema_exists()  # noqa: SLF001  # Accessing non-public member

        # Initialize the catalog and state backends
        self._catalog_backend = SqlCatalogBackend(
            sql_config=self,
            table_prefix=self.table_prefix or "",
        )
        self._state_backend = SqlStateBackend(
            sql_config=self,
            table_prefix=self.table_prefix or "",
        )

        # Now we can create the SQL read processor
        self._read_processor = self._sql_processor_class(
            sql_config=self,
            catalog_provider=self._catalog_backend.get_full_catalog_provider(),
            state_writer=StdOutStateWriter(),  # Shouldn't be needed for the read-only processor
            temp_dir=self.cache_dir,
            temp_file_cleanup=self.cleanup,
        )

    @property
    def config_hash(self) -> str | None:
        """Return a hash of the cache configuration.

        This is the same as the SQLConfig hash from the superclass.
        """
        return super(SqlConfig, self).config_hash

    def execute_sql(self, sql: str | list[str]) -> None:
        """Execute one or more SQL statements against the cache's SQL backend.

        If multiple SQL statements are given, they are executed in order,
        within the same transaction.

        This method is useful for creating tables, indexes, and other
        schema objects in the cache. It does not return any results and it
        automatically closes the connection after executing all statements.

        This method is not intended for querying data. For that, use the `get_records`
        method - or for a low-level interface, use the `get_sql_engine` method.

        If any of the statements fail, the transaction is canceled and an exception
        is raised. Most databases will rollback the transaction in this case.
        """
        if isinstance(sql, str):
            # Coerce to a list if a single string is given
            sql = [sql]

        with self.processor.get_sql_connection() as connection:
            for sql_statement in sql:
                connection.execute(text(sql_statement))

    @final
    @property
    def processor(self) -> SqlProcessorBase:
        """Return the SQL processor instance."""
        return self._read_processor

    def run_sql_query(
        self,
        sql_query: str,
        *,
        max_records: int | None = None,
    ) -> list[dict[str, Any]]:
        """Run a SQL query against the cache and return results as a list of dictionaries.

        This method is designed for single DML statements like SELECT, SHOW, or DESCRIBE.
        For DDL statements or multiple statements, use the processor directly.

        Args:
            sql_query: The SQL query to execute.
            max_records: Maximum number of records to return. If None, returns all records.

        Returns:
            List of dictionaries representing the query results.
        """
        # Execute the SQL within a connection context to ensure the connection stays open
        # while we fetch the results
        sql_text = text(sql_query) if isinstance(sql_query, str) else sql_query

        with self.processor.get_sql_connection() as conn:
            try:
                result = conn.execute(sql_text)
            except (
                sqlalchemy_exc.ProgrammingError,
                sqlalchemy_exc.SQLAlchemyError,
            ) as ex:
                msg = f"Error when executing SQL:\n{sql_query}\n{type(ex).__name__}{ex!s}"
                raise RuntimeError(msg) from ex

            # Convert the result to a list of dictionaries while connection is still open
            if result.returns_rows:
                # Get column names
                columns = list(result.keys()) if result.keys() else []

                # Fetch rows efficiently based on limit
                if max_records is not None:
                    rows = result.fetchmany(max_records)
                else:
                    rows = result.fetchall()

                return [dict(zip(columns, row, strict=True)) for row in rows]

            # For non-SELECT queries (INSERT, UPDATE, DELETE, etc.)
            return []

    def get_record_processor(
        self,
        source_name: str,
        catalog_provider: CatalogProvider,
        state_writer: StateWriterBase | None = None,
    ) -> SqlProcessorBase:
        """Return a record processor for the specified source name and catalog.

        We first register the source and its catalog with the catalog manager. Then we create a new
        SQL processor instance with (only) the given input catalog.

        For the state writer, we use a state writer which stores state in an internal SQL table.
        """
        # First register the source and catalog into durable storage. This is necessary to ensure
        # that we can later retrieve the catalog information.
        self.register_source(
            source_name=source_name,
            incoming_source_catalog=catalog_provider.configured_catalog,
            stream_names=set(catalog_provider.stream_names),
        )

        # Next create a new SQL processor instance with the given catalog - and a state writer
        # that writes state to the internal SQL table and associates with the given source name.
        return self._sql_processor_class(
            sql_config=self,
            catalog_provider=catalog_provider,
            state_writer=state_writer or self.get_state_writer(source_name=source_name),
            temp_dir=self.cache_dir,
            temp_file_cleanup=self.cleanup,
        )

    # Read methods:

    def get_records(
        self,
        stream_name: str,
    ) -> CachedDataset:
        """Uses SQLAlchemy to select all rows from the table."""
        return CachedDataset(self, stream_name)

    def get_pandas_dataframe(
        self,
        stream_name: str,
    ) -> pd.DataFrame:
        """Return a Pandas data frame with the stream's data."""
        table_name = self._read_processor.get_sql_table_name(stream_name)
        engine = self.get_sql_engine()
        return pd.read_sql_table(table_name, engine, schema=self.schema_name)

    def get_arrow_dataset(
        self,
        stream_name: str,
        *,
        max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
    ) -> ds.Dataset:
        """Return an Arrow Dataset with the stream's data."""
        table_name = self._read_processor.get_sql_table_name(stream_name)
        engine = self.get_sql_engine()

        # Read the table in chunks to handle large tables that do not fit in memory
        pandas_chunks = pd.read_sql_table(
            table_name=table_name,
            con=engine,
            schema=self.schema_name,
            chunksize=max_chunk_size,
        )

        arrow_batches_list = []
        arrow_schema = None

        for pandas_chunk in pandas_chunks:
            if arrow_schema is None:
                # Initialize the schema with the first chunk
                arrow_schema = pa.Schema.from_pandas(pandas_chunk)

            # Convert each pandas chunk to an Arrow record batch
            arrow_batch = pa.RecordBatch.from_pandas(pandas_chunk, schema=arrow_schema)
            arrow_batches_list.append(arrow_batch)

        return ds.dataset(arrow_batches_list)

    @final
    @property
    def streams(self) -> dict[str, CachedDataset]:
        """Return a mapping of stream names to cached datasets."""
        result = {}
        stream_names = set(self._catalog_backend.stream_names)

        for stream_name in stream_names:
            result[stream_name] = CachedDataset(self, stream_name)

        return result

    @final
    def __len__(self) -> int:
        """Return the number of streams."""
        return len(self._catalog_backend.stream_names)

    @final
    def __bool__(self) -> bool:
        """Always True.

        This is needed so that caches with zero streams are not falsey (None-like).
        """
        return True

    def get_state_provider(
        self,
        source_name: str,
        *,
        refresh: bool = True,
        destination_name: str | None = None,
    ) -> StateProviderBase:
        """Return a state provider for the specified source name."""
        return self._state_backend.get_state_provider(
            source_name=source_name,
            table_prefix=self.table_prefix or "",
            refresh=refresh,
            destination_name=destination_name,
        )

    def get_state_writer(
        self,
        source_name: str,
        destination_name: str | None = None,
    ) -> StateWriterBase:
        """Return a state writer for the specified source name.

        If syncing to the cache, `destination_name` should be `None`.
        If syncing to a destination, `destination_name` should be the destination name.
        """
        return self._state_backend.get_state_writer(
            source_name=source_name,
            destination_name=destination_name,
        )

    def register_source(
        self,
        source_name: str,
        incoming_source_catalog: ConfiguredAirbyteCatalog,
        stream_names: set[str],
    ) -> None:
        """Register the source name and catalog."""
        self._catalog_backend.register_source(
            source_name=source_name,
            incoming_source_catalog=incoming_source_catalog,
            incoming_stream_names=stream_names,
        )

    def create_source_tables(
        self,
        source: Source,
        streams: Literal["*"] | list[str] | None = None,
    ) -> None:
        """Create tables in the cache for the provided source if they do not exist already.

        Tables are created based upon the Source's catalog.

        Args:
            source: The source to create tables for.
            streams: Stream names to create tables for. If None, use the Source's selected_streams
                or "*" if neither is set. If "*", all available streams will be used.
        """
        if streams is None:
            streams = source.get_selected_streams() or "*"

        catalog_provider = CatalogProvider(source.get_configured_catalog(streams=streams))

        # Register the incoming source catalog
        self.register_source(
            source_name=source.name,
            incoming_source_catalog=catalog_provider.configured_catalog,
            stream_names=set(catalog_provider.stream_names),
        )

        # Ensure schema exists
        self.processor._ensure_schema_exists()  # noqa: SLF001  # Accessing non-public member

        # Create tables for each stream if they don't exist
        for stream_name in catalog_provider.stream_names:
            self.processor._ensure_final_table_exists(  # noqa: SLF001
                stream_name=stream_name,
                create_if_missing=True,
            )

    def __getitem__(self, stream: str) -> CachedDataset:
        """Return a dataset by stream name."""
        return self.streams[stream]

    def __contains__(self, stream: str) -> bool:
        """Return whether a stream is in the cache."""
        return stream in (self._catalog_backend.stream_names)

    def __iter__(  # type: ignore [override]  # Overriding Pydantic model method
        self,
    ) -> Iterator[tuple[str, Any]]:
        """Iterate over the streams in the cache."""
        return ((name, dataset) for name, dataset in self.streams.items())

    def _write_airbyte_message_stream(
        self,
        stdin: IO[str] | AirbyteMessageIterator,
        *,
        catalog_provider: CatalogProvider,
        write_strategy: WriteStrategy,
        state_writer: StateWriterBase | None = None,
        progress_tracker: ProgressTracker,
    ) -> None:
        """Read from the connector and write to the cache."""
        cache_processor = self.get_record_processor(
            source_name=self.name,
            catalog_provider=catalog_provider,
            state_writer=state_writer,
        )
        cache_processor.process_airbyte_messages(
            messages=stdin,
            write_strategy=write_strategy,
            progress_tracker=progress_tracker,
        )
        progress_tracker.log_cache_processing_complete()
class CacheBase(SqlConfig, AirbyteWriterInterface)
Base configuration for a cache.
Caches inherit from the matching SqlConfig class, which provides the SQL config settings and basic connectivity to the SQL database.

The cache is responsible for managing the state of the data synced to the cache, including the stream catalog and stream state. The cache also provides the mechanism to read and write data to the SQL backend specified in the SqlConfig class.
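For orientation, here is a minimal usage sketch. It assumes a concrete cache subclass such as DuckDBCache (with a db_path setting) and the source-faker connector; the connector name, configuration values, and file path are illustrative, not part of this base class.

import airbyte as ab
from airbyte.caches import DuckDBCache

# A concrete CacheBase subclass; db_path is assumed here for illustration.
cache = DuckDBCache(db_path="./example.duckdb")

# Read a small amount of sample data into the cache.
source = ab.get_source("source-faker", config={"count": 100})
source.check()
source.select_all_streams()
source.read(cache=cache)

# Streams synced to the cache are now available as CachedDataset objects.
print(list(cache.streams))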
CacheBase(**data: Any)
Initialize the cache and backends.
paired_destination_config: Any | dict[str, Any]
Return a dictionary of destination configuration values.
config_hash: str | None
Return a hash of the cache configuration.
This is the same as the SQLConfig hash from the superclass.
def execute_sql(self, sql: str | list[str]) -> None
Execute one or more SQL statements against the cache's SQL backend.
If multiple SQL statements are given, they are executed in order, within the same transaction.
This method is useful for creating tables, indexes, and other schema objects in the cache. It does not return any results and it automatically closes the connection after executing all statements.
This method is not intended for querying data. For that, use the get_records method - or for a low-level interface, use the get_sql_engine method.
If any of the statements fail, the transaction is canceled and an exception is raised. Most databases will rollback the transaction in this case.
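A hedged example of the intended DDL-style usage; it assumes `cache` is an initialized cache instance (as in the class-level sketch above), and the table and index names are made up for illustration.

# `cache` is an initialized cache instance; object names below are illustrative.
cache.execute_sql(
    [
        "CREATE TABLE IF NOT EXISTS sync_audit (run_id TEXT, run_at TIMESTAMP)",
        "CREATE INDEX IF NOT EXISTS idx_sync_audit_run_id ON sync_audit (run_id)",
    ]
)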
processor: SqlProcessorBase
Return the SQL processor instance.
def run_sql_query(self, sql_query: str, *, max_records: int | None = None) -> list[dict[str, Any]]
Run a SQL query against the cache and return results as a list of dictionaries.
This method is designed for single DML statements like SELECT, SHOW, or DESCRIBE. For DDL statements or multiple statements, use the processor directly.
Arguments:
- sql_query: The SQL query to execute
- max_records: Maximum number of records to return. If None, returns all records.
Returns:
List of dictionaries representing the query results
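A short sketch of querying cached data. It assumes `cache` is a populated cache and that a table named "users" exists in the cache schema; both are assumptions for illustration.

# `cache` is a populated cache instance; the table name is illustrative.
rows = cache.run_sql_query(
    "SELECT id, name FROM users ORDER BY id",
    max_records=10,
)
for row in rows:
    print(row["id"], row["name"])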
def get_record_processor(self, source_name: str, catalog_provider: CatalogProvider, state_writer: StateWriterBase | None = None) -> SqlProcessorBase
Return a record processor for the specified source name and catalog.
We first register the source and its catalog with the catalog manager. Then we create a new SQL processor instance with (only) the given input catalog.
For the state writer, we use a state writer which stores state in an internal SQL table.
def get_records(self, stream_name: str) -> CachedDataset
Uses SQLAlchemy to select all rows from the table.
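Illustrative iteration over the returned CachedDataset, assuming a populated cache with a stream named "users" (the stream name is an assumption).

# `cache` is a populated cache; "users" is an illustrative stream name.
dataset = cache.get_records("users")
for record in dataset:
    print(record)  # each record maps column names to values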
def get_pandas_dataframe(self, stream_name: str) -> pd.DataFrame
Return a Pandas data frame with the stream's data.
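For example, assuming a populated cache with a "users" stream (an illustrative name):

# `cache` is a populated cache; "users" is an illustrative stream name.
df = cache.get_pandas_dataframe("users")
print(df.shape)
print(df.head())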
def get_arrow_dataset(self, stream_name: str, *, max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE) -> ds.Dataset
Return an Arrow Dataset with the stream's data.
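A sketch of consuming the returned Arrow dataset, assuming a populated cache with a "users" stream; to_table() materializes the pyarrow dataset into a pyarrow.Table.

# `cache` is a populated cache; "users" is an illustrative stream name.
arrow_ds = cache.get_arrow_dataset("users", max_chunk_size=50_000)
arrow_table = arrow_ds.to_table()  # materialize the pyarrow.dataset.Dataset
print(arrow_table.num_rows, arrow_table.schema)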
streams: dict[str, CachedDataset]
Return a mapping of stream names to CachedDataset objects.
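Because streams backs the class's mapping-style dunders (__getitem__, __contains__, __len__, __iter__), a populated cache can be used like a read-only mapping of stream names; the "users" name below is illustrative.

# `cache` is a populated cache; "users" is an illustrative stream name.
print(len(cache))             # number of streams in the cache
print("users" in cache)       # membership check by stream name
users_dataset = cache["users"]  # equivalent to cache.streams["users"]
for name, dataset in cache:     # iteration yields (name, CachedDataset) pairs
    print(name, dataset)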
def get_state_provider(self, source_name: str, *, refresh: bool = True, destination_name: str | None = None) -> StateProviderBase
Return a state provider for the specified source name.
def get_state_writer(self, source_name: str, destination_name: str | None = None) -> StateWriterBase
Return a state writer for the specified source name.
If syncing to the cache, destination_name should be None. If syncing to a destination, destination_name should be the destination name.
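A minimal sketch of obtaining the state helpers, assuming a cache that has previously synced a source named "source-faker" (an illustrative name).

# The provider reads state persisted in the cache's internal state table;
# the writer persists new state messages for the named source.
state_provider = cache.get_state_provider("source-faker")
state_writer = cache.get_state_writer("source-faker")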
def register_source(self, source_name: str, incoming_source_catalog: ConfiguredAirbyteCatalog, stream_names: set[str]) -> None
Register the source name and catalog.
def create_source_tables(self, source: Source, streams: Literal["*"] | list[str] | None = None) -> None
Create tables in the cache for the provided source if they do not exist already.
Tables are created based upon the Source's catalog.
Arguments:
- source: The source to create tables for.
- streams: Stream names to create tables for. If None, use the Source's selected_streams or "*" if neither is set. If "*", all available streams will be used.
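A hedged sketch of pre-creating tables before a sync. It assumes `cache` is an initialized cache instance and the source-faker connector; the connector, config, and stream names are illustrative.

import airbyte as ab

# Connector name, config, and stream names are illustrative.
source = ab.get_source("source-faker", config={"count": 100})
source.select_streams(["users", "products"])

# Create empty final tables for the selected streams (no data is read yet).
cache.create_source_tables(source)

# Or create tables for every stream the source offers:
cache.create_source_tables(source, streams="*")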
model_config

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is None:
        pydantic_private = {}
        for name, private_attr in self.__private_attributes__.items():
            default = private_attr.get_default()
            if default is not PydanticUndefined:
                pydantic_private[name] = default
        object_setattr(self, '__pydantic_private__', pydantic_private)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Arguments:
- self: The BaseModel instance.
- context: The context.
Inherited Members
- airbyte.shared.sql_processor.SqlConfig
- schema_name
- table_prefix
- get_sql_alchemy_url
- get_database_name
- get_create_table_extra_clauses
- get_sql_alchemy_connect_args
- get_sql_engine
- get_vendor_client
- pydantic.main.BaseModel
- model_extra
- model_fields_set
- model_construct
- model_copy
- model_dump
- model_dump_json
- model_json_schema
- model_parametrized_name
- model_rebuild
- model_validate
- model_validate_json
- model_validate_strings
- dict
- json
- parse_obj
- parse_raw
- parse_file
- from_orm
- construct
- copy
- schema
- schema_json
- validate
- update_forward_refs
- model_fields
- model_computed_fields
- airbyte._writers.base.AirbyteWriterInterface
- name