184 lines
7.0 KiB
Diff
184 lines
7.0 KiB
Diff
|
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
|
||
|
index 68bf5db..92d375e 100644
|
||
|
--- a/.github/workflows/build.yml
|
||
|
+++ b/.github/workflows/build.yml
|
||
|
@@ -42,6 +42,9 @@ jobs:
|
||
|
run: |
|
||
|
sudo apt-get -qq update
|
||
|
pip install -e ".[dev]"
|
||
|
+
|
||
|
+ - uses: mxschmitt/action-tmate@v3
|
||
|
+
|
||
|
- name: Run SQLite tests
|
||
|
env:
|
||
|
DATABASE_URL: "sqlite:///:memory:"
|
||
|
diff --git a/dataset/database.py b/dataset/database.py
|
||
|
index d8a07ad..4bc31fd 100644
|
||
|
--- a/dataset/database.py
|
||
|
+++ b/dataset/database.py
|
||
|
@@ -106,7 +106,7 @@ class Database(object):
|
||
|
@property
|
||
|
def metadata(self):
|
||
|
"""Return a SQLAlchemy schema cache object."""
|
||
|
- return MetaData(schema=self.schema, bind=self.executable)
|
||
|
+ return MetaData(schema=self.schema)
|
||
|
|
||
|
@property
|
||
|
def in_transaction(self):
|
||
|
@@ -127,6 +127,8 @@ class Database(object):
|
||
|
"""
|
||
|
if not hasattr(self.local, "tx"):
|
||
|
self.local.tx = []
|
||
|
+ if self.executable.in_transaction():
|
||
|
+ self.executable.commit()
|
||
|
self.local.tx.append(self.executable.begin())
|
||
|
|
||
|
def commit(self):
|
||
|
diff --git a/dataset/table.py b/dataset/table.py
|
||
|
index 08b806b..2f27060 100644
|
||
|
--- a/dataset/table.py
|
||
|
+++ b/dataset/table.py
|
||
|
@@ -116,7 +116,12 @@ class Table(object):
|
||
|
Returns the inserted row's primary key.
|
||
|
"""
|
||
|
row = self._sync_columns(row, ensure, types=types)
|
||
|
- res = self.db.executable.execute(self.table.insert(row))
|
||
|
+ res = self.db.executable.execute(self.table.insert(), row)
|
||
|
+ # SQLAlchemy 2.0 (since 2.0.0b1) no longer autocommits: commit explicitly unless a transaction is tracked in db.local.tx
|
||
|
+ if hasattr(self.db.local, "tx") and self.db.local.tx:
|
||
|
+ pass
|
||
|
+ else:
|
||
|
+ self.db.executable.commit()
|
||
|
if len(res.inserted_primary_key) > 0:
|
||
|
return res.inserted_primary_key[0]
|
||
|
return True
|
||
|
@@ -181,7 +186,8 @@ class Table(object):
|
||
|
# Insert when chunk_size is fulfilled or this is the last row
|
||
|
if len(chunk) == chunk_size or index == len(rows) - 1:
|
||
|
chunk = pad_chunk_columns(chunk, columns)
|
||
|
- self.table.insert().execute(chunk)
|
||
|
+ with self.db.engine.begin() as conn:
|
||
|
+ conn.execute(self.table.insert(), chunk)
|
||
|
chunk = []
|
||
|
|
||
|
def update(self, row, keys, ensure=None, types=None, return_count=False):
|
||
|
@@ -206,7 +212,7 @@ class Table(object):
|
||
|
clause = self._args_to_clause(args)
|
||
|
if not len(row):
|
||
|
return self.count(clause)
|
||
|
- stmt = self.table.update(whereclause=clause, values=row)
|
||
|
+ stmt = self.table.update().where(clause).values(row)
|
||
|
rp = self.db.executable.execute(stmt)
|
||
|
if rp.supports_sane_rowcount():
|
||
|
return rp.rowcount
|
||
|
@@ -241,10 +247,9 @@ class Table(object):
|
||
|
# Update when chunk_size is fulfilled or this is the last row
|
||
|
if len(chunk) == chunk_size or index == len(rows) - 1:
|
||
|
cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys]
|
||
|
- stmt = self.table.update(
|
||
|
- whereclause=and_(True, *cl),
|
||
|
- values={col: bindparam(col, required=False) for col in columns},
|
||
|
- )
|
||
|
+ stmt = self.table.update()\
|
||
|
+ .where(and_(True, *cl))\
|
||
|
+ .values({col: bindparam(col, required=False) for col in columns})
|
||
|
self.db.executable.execute(stmt, chunk)
|
||
|
chunk = []
|
||
|
|
||
|
@@ -293,7 +298,7 @@ class Table(object):
|
||
|
if not self.exists:
|
||
|
return False
|
||
|
clause = self._args_to_clause(filters, clauses=clauses)
|
||
|
- stmt = self.table.delete(whereclause=clause)
|
||
|
+ stmt = self.table.delete().where(clause)
|
||
|
rp = self.db.executable.execute(stmt)
|
||
|
return rp.rowcount > 0
|
||
|
|
||
|
@@ -303,7 +308,7 @@ class Table(object):
|
||
|
self._columns = None
|
||
|
try:
|
||
|
self._table = SQLATable(
|
||
|
- self.name, self.db.metadata, schema=self.db.schema, autoload=True
|
||
|
+ self.name, self.db.metadata, schema=self.db.schema, autoload_with=self.db.engine,
|
||
|
)
|
||
|
except NoSuchTableError:
|
||
|
self._table = None
|
||
|
@@ -625,7 +630,7 @@ class Table(object):
|
||
|
|
||
|
order_by = self._args_to_order_by(order_by)
|
||
|
args = self._args_to_clause(kwargs, clauses=_clauses)
|
||
|
- query = self.table.select(whereclause=args, limit=_limit, offset=_offset)
|
||
|
+ query = self.table.select().where(args).limit(_limit).offset(_offset)
|
||
|
if len(order_by):
|
||
|
query = query.order_by(*order_by)
|
||
|
|
||
|
@@ -666,7 +671,7 @@ class Table(object):
|
||
|
return 0
|
||
|
|
||
|
args = self._args_to_clause(kwargs, clauses=_clauses)
|
||
|
- query = select([func.count()], whereclause=args)
|
||
|
+ query = select(func.count()).where(args)
|
||
|
query = query.select_from(self.table)
|
||
|
rp = self.db.executable.execute(query)
|
||
|
return rp.fetchone()[0]
|
||
|
@@ -703,12 +708,10 @@ class Table(object):
|
||
|
if not len(columns):
|
||
|
return iter([])
|
||
|
|
||
|
- q = expression.select(
|
||
|
- columns,
|
||
|
- distinct=True,
|
||
|
- whereclause=clause,
|
||
|
- order_by=[c.asc() for c in columns],
|
||
|
- )
|
||
|
+ q = expression.select(*columns)\
|
||
|
+ .distinct(True)\
|
||
|
+ .where(clause)\
|
||
|
+ .order_by(*(c.asc() for c in columns))
|
||
|
return self.db.query(q)
|
||
|
|
||
|
# Legacy methods for running find queries.
|
||
|
diff --git a/setup.py b/setup.py
|
||
|
index 0691373..fb794a4 100644
|
||
|
--- a/setup.py
|
||
|
+++ b/setup.py
|
||
|
@@ -30,7 +30,6 @@ setup(
|
||
|
include_package_data=False,
|
||
|
zip_safe=False,
|
||
|
install_requires=[
|
||
|
- "sqlalchemy >= 1.3.2, < 2.0.0",
|
||
|
"alembic >= 0.6.2",
|
||
|
"banal >= 1.0.1",
|
||
|
],
|
||
|
diff --git a/test/test_dataset.py b/test/test_dataset.py
|
||
|
index f7c94eb..5861fbc 100644
|
||
|
--- a/test/test_dataset.py
|
||
|
+++ b/test/test_dataset.py
|
||
|
@@ -14,7 +14,10 @@ class DatabaseTestCase(unittest.TestCase):
|
||
|
def setUp(self):
|
||
|
self.db = connect()
|
||
|
self.tbl = self.db["weather"]
|
||
|
+ assert not self.db.has_table("weather")
|
||
|
self.tbl.insert_many(TEST_DATA)
|
||
|
+ # the table is only created once the insert statement has run
|
||
|
+ assert self.db.has_table("weather")
|
||
|
|
||
|
def tearDown(self):
|
||
|
for table in self.db.tables:
|
||
|
@@ -83,7 +86,6 @@ class DatabaseTestCase(unittest.TestCase):
|
||
|
def test_create_table_shorthand1(self):
|
||
|
pid = "int_id"
|
||
|
table = self.db.get_table("foo5", pid)
|
||
|
- assert table.table.exists
|
||
|
assert len(table.table.columns) == 1, table.table.columns
|
||
|
assert pid in table.table.c, table.table.c
|
||
|
|
||
|
@@ -98,7 +100,6 @@ class DatabaseTestCase(unittest.TestCase):
|
||
|
table = self.db.get_table(
|
||
|
"foo6", primary_id=pid, primary_type=self.db.types.string(255)
|
||
|
)
|
||
|
- assert table.table.exists
|
||
|
assert len(table.table.columns) == 1, table.table.columns
|
||
|
assert pid in table.table.c, table.table.c
|
||
|
|