Review notes (free text ahead of the first "diff --git" header is skipped by git apply and patch):
  * The .github/workflows/build.yml change is dropped: the mxschmitt/action-tmate step is an
    interactive SSH debugging aid that stalls every CI run and must not be merged.
  * Database.begin() no longer commits an enclosing explicit transaction; nested calls use a SAVEPOINT.
  * Every write path in Table (insert, insert_many, update, update_many, delete) now runs on
    db.executable and commits when no explicit transaction is open, since SQLAlchemy 2.0 has no autocommit.
  * Table reflection uses db.executable again (the connection the old MetaData bind pointed at).
  * Select.distinct() is called without arguments; positional arguments mean DISTINCT ON.
  * Hunk headers are recomputed for the new line counts; the post-image blob ids on the "index"
    lines are carried over from the original patch and were not regenerated.

diff --git a/dataset/database.py b/dataset/database.py
index d8a07ad..4bc31fd 100644
--- a/dataset/database.py
+++ b/dataset/database.py
@@ -106,7 +106,7 @@ class Database(object):
     @property
     def metadata(self):
         """Return a SQLAlchemy schema cache object."""
-        return MetaData(schema=self.schema, bind=self.executable)
+        return MetaData(schema=self.schema)
 
     @property
     def in_transaction(self):
@@ -127,6 +127,16 @@ class Database(object):
         """
         if not hasattr(self.local, "tx"):
             self.local.tx = []
+        if self.local.tx:
+            # Nested begin(): keep the enclosing transaction open and stack a
+            # SAVEPOINT on it. Committing here would silently end the outer
+            # transaction, and Connection.begin() raises when one is active.
+            self.local.tx.append(self.executable.begin_nested())
+            return
+        if self.executable.in_transaction():
+            # SQLAlchemy 2.0 autobegins on the first execute(); finish that
+            # implicit transaction so an explicit one can start.
+            self.executable.commit()
         self.local.tx.append(self.executable.begin())
 
     def commit(self):
diff --git a/dataset/table.py b/dataset/table.py
index 08b806b..2f27060 100644
--- a/dataset/table.py
+++ b/dataset/table.py
@@ -116,7 +116,11 @@ class Table(object):
         Returns the inserted row's primary key.
         """
         row = self._sync_columns(row, ensure, types=types)
-        res = self.db.executable.execute(self.table.insert(row))
+        res = self.db.executable.execute(self.table.insert(), row)
+        # SQLAlchemy 2.0 removed autocommit: persist the row right away unless
+        # the caller opened an explicit transaction with Database.begin().
+        if not self.db.in_transaction:
+            self.db.executable.commit()
         if len(res.inserted_primary_key) > 0:
             return res.inserted_primary_key[0]
         return True
@@ -181,7 +185,9 @@ class Table(object):
             # Insert when chunk_size is fulfilled or this is the last row
             if len(chunk) == chunk_size or index == len(rows) - 1:
                 chunk = pad_chunk_columns(chunk, columns)
-                self.table.insert().execute(chunk)
+                self.db.executable.execute(self.table.insert(), chunk)
+                if not self.db.in_transaction:
+                    self.db.executable.commit()
                 chunk = []
 
     def update(self, row, keys, ensure=None, types=None, return_count=False):
@@ -206,7 +212,9 @@ class Table(object):
         clause = self._args_to_clause(args)
         if not len(row):
             return self.count(clause)
-        stmt = self.table.update(whereclause=clause, values=row)
+        stmt = self.table.update().where(clause).values(row)
         rp = self.db.executable.execute(stmt)
+        if not self.db.in_transaction:
+            self.db.executable.commit()
         if rp.supports_sane_rowcount():
             return rp.rowcount
@@ -241,10 +249,13 @@ class Table(object):
             # Update when chunk_size is fulfilled or this is the last row
             if len(chunk) == chunk_size or index == len(rows) - 1:
                 cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys]
-                stmt = self.table.update(
-                    whereclause=and_(True, *cl),
-                    values={col: bindparam(col, required=False) for col in columns},
-                )
+                stmt = (
+                    self.table.update()
+                    .where(and_(True, *cl))
+                    .values({col: bindparam(col, required=False) for col in columns})
+                )
                 self.db.executable.execute(stmt, chunk)
+                if not self.db.in_transaction:
+                    self.db.executable.commit()
                 chunk = []
 
@@ -293,7 +304,9 @@ class Table(object):
         if not self.exists:
             return False
         clause = self._args_to_clause(filters, clauses=clauses)
-        stmt = self.table.delete(whereclause=clause)
+        stmt = self.table.delete().where(clause)
         rp = self.db.executable.execute(stmt)
+        if not self.db.in_transaction:
+            self.db.executable.commit()
         return rp.rowcount > 0
 
@@ -303,7 +316,10 @@ class Table(object):
             self._columns = None
             try:
                 self._table = SQLATable(
-                    self.name, self.db.metadata, schema=self.db.schema, autoload=True
+                    self.name,
+                    self.db.metadata,
+                    schema=self.db.schema,
+                    autoload_with=self.db.executable,
                 )
             except NoSuchTableError:
                 self._table = None
@@ -625,7 +641,7 @@ class Table(object):
 
         order_by = self._args_to_order_by(order_by)
         args = self._args_to_clause(kwargs, clauses=_clauses)
-        query = self.table.select(whereclause=args, limit=_limit, offset=_offset)
+        query = self.table.select().where(args).limit(_limit).offset(_offset)
         if len(order_by):
             query = query.order_by(*order_by)
 
@@ -666,7 +682,7 @@ class Table(object):
             return 0
 
         args = self._args_to_clause(kwargs, clauses=_clauses)
-        query = select([func.count()], whereclause=args)
+        query = select(func.count()).where(args)
         query = query.select_from(self.table)
         rp = self.db.executable.execute(query)
         return rp.fetchone()[0]
@@ -703,12 +719,12 @@ class Table(object):
         if not len(columns):
             return iter([])
 
-        q = expression.select(
-            columns,
-            distinct=True,
-            whereclause=clause,
-            order_by=[c.asc() for c in columns],
-        )
+        q = (
+            expression.select(*columns)
+            .distinct()
+            .where(clause)
+            .order_by(*(c.asc() for c in columns))
+        )
         return self.db.query(q)
 
     # Legacy methods for running find queries.
Review note: the sqlalchemy requirement is re-pinned rather than deleted. The migrated code calls
Connection.commit() and positional select(), so an installer that resolves SQLAlchemy 1.x through
alembic's own dependency would break at runtime. The setup.py post-image blob id is not regenerated.

diff --git a/setup.py b/setup.py
index 0691373..fb794a4 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,7 @@ setup(
     include_package_data=False,
     zip_safe=False,
     install_requires=[
-        "sqlalchemy >= 1.3.2, < 2.0.0",
+        "sqlalchemy >= 2.0.0, < 3.0.0",
         "alembic >= 0.6.2",
         "banal >= 1.0.1",
     ],
diff --git a/test/test_dataset.py b/test/test_dataset.py
index f7c94eb..5861fbc 100644
--- a/test/test_dataset.py
+++ b/test/test_dataset.py
@@ -14,7 +14,10 @@ class DatabaseTestCase(unittest.TestCase):
     def setUp(self):
         self.db = connect()
         self.tbl = self.db["weather"]
+        assert not self.db.has_table("weather")
         self.tbl.insert_many(TEST_DATA)
+        # The table is created lazily; it only exists once rows are inserted.
+        assert self.db.has_table("weather")
 
     def tearDown(self):
         for table in self.db.tables:
@@ -83,7 +86,6 @@ class DatabaseTestCase(unittest.TestCase):
     def test_create_table_shorthand1(self):
         pid = "int_id"
         table = self.db.get_table("foo5", pid)
-        assert table.table.exists
         assert len(table.table.columns) == 1, table.table.columns
         assert pid in table.table.c, table.table.c
 
@@ -98,7 +100,6 @@ class DatabaseTestCase(unittest.TestCase):
         table = self.db.get_table(
             "foo6", primary_id=pid, primary_type=self.db.types.string(255)
         )
-        assert table.table.exists
         assert len(table.table.columns) == 1, table.table.columns
         assert pid in table.table.c, table.table.c
 