From 90d5b8e220db51465e4dbac8df6e4bd4941c9ba6 Mon Sep 17 00:00:00 2001 From: Steve Kowalik Date: Tue, 26 Sep 2023 11:59:39 +1000 Subject: [PATCH] Migrate to SQLAlchemy 2 https://github.com/wireservice/agate-sql/pull/40 Remove the upper bound on SQLAlchemy by converting the code idioms in use to support both SQLAlchemy 1.4 and SQLAlchemy 2, and only setting a lower bound SQLAlchemy of >= 1.4. Closes #39 diff --git a/agatesql/table.py b/agatesql/table.py index b141937..e4efe91 100644 --- a/agatesql/table.py +++ b/agatesql/table.py @@ -82,2 +82,2 @@ def from_sql(cls, connection_or_string, table_name): - metadata = MetaData(connection) - sql_table = Table(table_name, metadata, autoload=True, autoload_with=connection) + metadata = MetaData() + sql_table = Table(table_name, metadata, autoload_with=connection) @@ -113 +113 @@ def from_sql(cls, connection_or_string, table_name): - s = select([sql_table]) + s = select(sql_table) @@ -182 +182 @@ def make_sql_table(table, table_name, dialect=None, db_schema=None, constraints= - metadata = MetaData(connection) + metadata = MetaData() @@ -276,2 +276,3 @@ def to_sql(self, connection_or_string, table_name, overwrite=False, - if overwrite: - sql_table.drop(checkfirst=True) + with connection.begin(): + if overwrite: + sql_table.drop(bind=connection, checkfirst=True) @@ -279 +280 @@ def to_sql(self, connection_or_string, table_name, overwrite=False, - sql_table.create(checkfirst=create_if_not_exists) + sql_table.create(bind=connection, checkfirst=create_if_not_exists) @@ -282,13 +283,14 @@ def to_sql(self, connection_or_string, table_name, overwrite=False, - insert = sql_table.insert() - for prefix in prefixes: - insert = insert.prefix_with(prefix) - if chunk_size is None: - connection.execute(insert, [dict(zip(self.column_names, row)) for row in self.rows]) - else: - number_of_rows = len(self.rows) - for index in range((number_of_rows - 1) // chunk_size + 1): - end_index = (index + 1) * chunk_size - if end_index > number_of_rows: - end_index = number_of_rows - connection.execute(insert, [dict(zip(self.column_names, row)) for row in - self.rows[index * chunk_size:end_index]]) + with connection.begin(): + insert = sql_table.insert() + for prefix in prefixes: + insert = insert.prefix_with(prefix) + if chunk_size is None: + connection.execute(insert, [dict(zip(self.column_names, row)) for row in self.rows]) + else: + number_of_rows = len(self.rows) + for index in range((number_of_rows - 1) // chunk_size + 1): + end_index = (index + 1) * chunk_size + if end_index > number_of_rows: + end_index = number_of_rows + connection.execute(insert, [dict(zip(self.column_names, row)) for row in + self.rows[index * chunk_size:end_index]]) @@ -354 +356 @@ def sql_query(self, query, table_name='agate'): - rows = connection.execute(q) + rows = connection.exec_driver_sql(q) diff --git a/setup.py b/setup.py index 3905203..7257399 100644 --- a/setup.py +++ b/setup.py @@ -37 +37 @@ setup( - 'sqlalchemy<2', + 'sqlalchemy>=1.4',