xin/pkgs/kobuddy.diff
2024-08-29 07:34:55 -06:00

184 lines
7.0 KiB
Diff

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 68bf5db..92d375e 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -42,6 +42,9 @@ jobs:
run: |
sudo apt-get -qq update
pip install -e ".[dev]"
+ # Interactive SSH debug session; only opened when the job is re-run with debug logging enabled.
+ - uses: mxschmitt/action-tmate@v3
+ if: ${{ runner.debug == '1' }}
- name: Run SQLite tests
env:
DATABASE_URL: "sqlite:///:memory:"
diff --git a/dataset/database.py b/dataset/database.py
index d8a07ad..4bc31fd 100644
--- a/dataset/database.py
+++ b/dataset/database.py
@@ -106,7 +106,7 @@ class Database(object):
@property
def metadata(self):
"""Return a SQLAlchemy schema cache object."""
- return MetaData(schema=self.schema, bind=self.executable)
+ return MetaData(schema=self.schema)
@property
def in_transaction(self):
@@ -127,6 +127,8 @@ class Database(object):
"""
if not hasattr(self.local, "tx"):
self.local.tx = []
+ if self.executable.in_transaction():
+ self.executable.commit()
self.local.tx.append(self.executable.begin())
def commit(self):
diff --git a/dataset/table.py b/dataset/table.py
index 08b806b..2f27060 100644
--- a/dataset/table.py
+++ b/dataset/table.py
@@ -116,7 +116,10 @@ class Table(object):
Returns the inserted row's primary key.
"""
row = self._sync_columns(row, ensure, types=types)
- res = self.db.executable.execute(self.table.insert(row))
+ res = self.db.executable.execute(self.table.insert(), row)
+ # SQLAlchemy 2.0 removed autocommit: commit unless an explicit db.begin() transaction is open
+ if not getattr(self.db.local, "tx", None):
+ self.db.executable.commit()
if len(res.inserted_primary_key) > 0:
return res.inserted_primary_key[0]
return True
@@ -181,7 +184,10 @@ class Table(object):
# Insert when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
chunk = pad_chunk_columns(chunk, columns)
- self.table.insert().execute(chunk)
+ # Same connection as insert(), so the rows join any open db.begin() transaction
+ self.db.executable.execute(self.table.insert(), chunk)
+ if not getattr(self.db.local, "tx", None):
+ self.db.executable.commit()
chunk = []
def update(self, row, keys, ensure=None, types=None, return_count=False):
@@ -206,7 +212,9 @@ class Table(object):
clause = self._args_to_clause(args)
if not len(row):
return self.count(clause)
- stmt = self.table.update(whereclause=clause, values=row)
+ stmt = self.table.update().where(clause).values(row)
rp = self.db.executable.execute(stmt)
+ if not getattr(self.db.local, "tx", None):
+ self.db.executable.commit()
if rp.supports_sane_rowcount():
return rp.rowcount
@@ -241,10 +249,11 @@ class Table(object):
# Update when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys]
- stmt = self.table.update(
- whereclause=and_(True, *cl),
- values={col: bindparam(col, required=False) for col in columns},
- )
+ stmt = self.table.update()\
+ .where(and_(True, *cl))\
+ .values({col: bindparam(col, required=False) for col in columns})
self.db.executable.execute(stmt, chunk)
+ if not getattr(self.db.local, "tx", None):
+ self.db.executable.commit()
chunk = []
@@ -293,7 +302,9 @@ class Table(object):
if not self.exists:
return False
clause = self._args_to_clause(filters, clauses=clauses)
- stmt = self.table.delete(whereclause=clause)
+ stmt = self.table.delete().where(clause)
rp = self.db.executable.execute(stmt)
+ if not getattr(self.db.local, "tx", None):
+ self.db.executable.commit()
return rp.rowcount > 0
@@ -303,7 +314,7 @@ class Table(object):
self._columns = None
try:
self._table = SQLATable(
- self.name, self.db.metadata, schema=self.db.schema, autoload=True
+ self.name, self.db.metadata, schema=self.db.schema, autoload_with=self.db.engine,
)
except NoSuchTableError:
self._table = None
@@ -625,7 +636,7 @@ class Table(object):
order_by = self._args_to_order_by(order_by)
args = self._args_to_clause(kwargs, clauses=_clauses)
- query = self.table.select(whereclause=args, limit=_limit, offset=_offset)
+ query = self.table.select().where(args).limit(_limit).offset(_offset)
if len(order_by):
query = query.order_by(*order_by)
@@ -666,7 +677,7 @@ class Table(object):
return 0
args = self._args_to_clause(kwargs, clauses=_clauses)
- query = select([func.count()], whereclause=args)
+ query = select(func.count()).where(args)
query = query.select_from(self.table)
rp = self.db.executable.execute(query)
return rp.fetchone()[0]
@@ -703,12 +714,10 @@ class Table(object):
if not len(columns):
return iter([])
- q = expression.select(
- columns,
- distinct=True,
- whereclause=clause,
- order_by=[c.asc() for c in columns],
- )
+ q = expression.select(*columns)\
+ .distinct()\
+ .where(clause)\
+ .order_by(*(c.asc() for c in columns))
return self.db.query(q)
# Legacy methods for running find queries.
diff --git a/setup.py b/setup.py
index 0691373..fb794a4 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,7 @@ setup(
include_package_data=False,
zip_safe=False,
install_requires=[
- "sqlalchemy >= 1.3.2, < 2.0.0",
+ "sqlalchemy >= 2.0.0",
"alembic >= 0.6.2",
"banal >= 1.0.1",
],
diff --git a/test/test_dataset.py b/test/test_dataset.py
index f7c94eb..5861fbc 100644
--- a/test/test_dataset.py
+++ b/test/test_dataset.py
@@ -14,7 +14,10 @@ class DatabaseTestCase(unittest.TestCase):
def setUp(self):
self.db = connect()
self.tbl = self.db["weather"]
+ assert not self.db.has_table("weather")
self.tbl.insert_many(TEST_DATA)
+ # table is only created after insert statement
+ assert self.db.has_table("weather")
def tearDown(self):
for table in self.db.tables:
@@ -83,7 +86,6 @@ class DatabaseTestCase(unittest.TestCase):
def test_create_table_shorthand1(self):
pid = "int_id"
table = self.db.get_table("foo5", pid)
- assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c
@@ -98,7 +100,6 @@ class DatabaseTestCase(unittest.TestCase):
table = self.db.get_table(
"foo6", primary_id=pid, primary_type=self.db.types.string(255)
)
- assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c