summaryrefslogtreecommitdiff
path: root/examples/sharding/attribute_shard.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/sharding/attribute_shard.py')
-rw-r--r--examples/sharding/attribute_shard.py117
1 files changed, 69 insertions, 48 deletions
diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py
index 0e19b69f3..48a3dc932 100644
--- a/examples/sharding/attribute_shard.py
+++ b/examples/sharding/attribute_shard.py
@@ -1,5 +1,13 @@
-from sqlalchemy import (create_engine, Table, Column, Integer,
- String, ForeignKey, Float, DateTime)
+from sqlalchemy import (
+ create_engine,
+ Table,
+ Column,
+ Integer,
+ String,
+ ForeignKey,
+ Float,
+ DateTime,
+)
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.ext.horizontal_shard import ShardedSession
from sqlalchemy.sql import operators, visitors
@@ -12,22 +20,24 @@ import datetime
# causes the id_generator() to use the same connection as that
# of an ongoing transaction within db1.
echo = True
-db1 = create_engine('sqlite://', echo=echo, pool_threadlocal=True)
-db2 = create_engine('sqlite://', echo=echo)
-db3 = create_engine('sqlite://', echo=echo)
-db4 = create_engine('sqlite://', echo=echo)
+db1 = create_engine("sqlite://", echo=echo, pool_threadlocal=True)
+db2 = create_engine("sqlite://", echo=echo)
+db3 = create_engine("sqlite://", echo=echo)
+db4 = create_engine("sqlite://", echo=echo)
# create session function. this binds the shard ids
# to databases within a ShardedSession and returns it.
create_session = sessionmaker(class_=ShardedSession)
-create_session.configure(shards={
- 'north_america': db1,
- 'asia': db2,
- 'europe': db3,
- 'south_america': db4
-})
+create_session.configure(
+ shards={
+ "north_america": db1,
+ "asia": db2,
+ "europe": db3,
+ "south_america": db4,
+ }
+)
# mappings and tables
@@ -40,9 +50,7 @@ Base = declarative_base()
# #1. Any other method will do just as well; UUID, hilo, application-specific,
# etc.
-ids = Table(
- 'ids', Base.metadata,
- Column('nextid', Integer, nullable=False))
+ids = Table("ids", Base.metadata, Column("nextid", Integer, nullable=False))
def id_generator(ctx):
@@ -52,6 +60,7 @@ def id_generator(ctx):
conn.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1}))
return nextid
+
# table setup. we'll store a lead table of continents/cities, and a secondary
# table storing locations. a particular row will be placed in the database
# whose shard id corresponds to the 'continent'. in this setup, secondary rows
@@ -67,7 +76,7 @@ class WeatherLocation(Base):
continent = Column(String(30), nullable=False)
city = Column(String(50), nullable=False)
- reports = relationship("Report", backref='location')
+ reports = relationship("Report", backref="location")
def __init__(self, continent, city):
self.continent = continent
@@ -79,14 +88,17 @@ class Report(Base):
id = Column(Integer, primary_key=True)
location_id = Column(
- 'location_id', Integer, ForeignKey('weather_locations.id'))
- temperature = Column('temperature', Float)
+ "location_id", Integer, ForeignKey("weather_locations.id")
+ )
+ temperature = Column("temperature", Float)
report_time = Column(
- 'report_time', DateTime, default=datetime.datetime.now)
+ "report_time", DateTime, default=datetime.datetime.now
+ )
def __init__(self, temperature):
self.temperature = temperature
+
# create tables
for db in (db1, db2, db3, db4):
Base.metadata.drop_all(db)
@@ -101,10 +113,10 @@ db1.execute(ids.insert(), nextid=1)
# we'll use a straight mapping of a particular set of "country"
# attributes to shard id.
shard_lookup = {
- 'North America': 'north_america',
- 'Asia': 'asia',
- 'Europe': 'europe',
- 'South America': 'south_america'
+ "North America": "north_america",
+ "Asia": "asia",
+ "Europe": "europe",
+ "South America": "south_america",
}
@@ -139,7 +151,7 @@ def id_chooser(query, ident):
# set things up.
return [query.lazy_loaded_from.identity_token]
else:
- return ['north_america', 'asia', 'europe', 'south_america']
+ return ["north_america", "asia", "europe", "south_america"]
def query_chooser(query):
@@ -168,7 +180,7 @@ def query_chooser(query):
ids.extend(shard_lookup[v] for v in value)
if len(ids) == 0:
- return ['north_america', 'asia', 'europe', 'south_america']
+ return ["north_america", "asia", "europe", "south_america"]
else:
return ids
@@ -208,13 +220,16 @@ def _get_query_comparisons(query):
def visit_binary(binary):
# special handling for "col IN (params)"
- if binary.left in clauses and \
- binary.operator == operators.in_op and \
- hasattr(binary.right, 'clauses'):
+ if (
+ binary.left in clauses
+ and binary.operator == operators.in_op
+ and hasattr(binary.right, "clauses")
+ ):
comparisons.append(
(
- binary.left, binary.operator,
- tuple(binds[bind] for bind in binary.right.clauses)
+ binary.left,
+ binary.operator,
+ tuple(binds[bind] for bind in binary.right.clauses),
)
)
elif binary.left in clauses and binary.right in binds:
@@ -232,29 +247,33 @@ def _get_query_comparisons(query):
# into a list.
if query._criterion is not None:
visitors.traverse_depthfirst(
- query._criterion, {},
- {'bindparam': visit_bindparam,
- 'binary': visit_binary,
- 'column': visit_column}
+ query._criterion,
+ {},
+ {
+ "bindparam": visit_bindparam,
+ "binary": visit_binary,
+ "column": visit_column,
+ },
)
return comparisons
+
# further configure create_session to use these functions
create_session.configure(
shard_chooser=shard_chooser,
id_chooser=id_chooser,
- query_chooser=query_chooser
+ query_chooser=query_chooser,
)
# save and load objects!
-tokyo = WeatherLocation('Asia', 'Tokyo')
-newyork = WeatherLocation('North America', 'New York')
-toronto = WeatherLocation('North America', 'Toronto')
-london = WeatherLocation('Europe', 'London')
-dublin = WeatherLocation('Europe', 'Dublin')
-brasilia = WeatherLocation('South America', 'Brasila')
-quito = WeatherLocation('South America', 'Quito')
+tokyo = WeatherLocation("Asia", "Tokyo")
+newyork = WeatherLocation("North America", "New York")
+toronto = WeatherLocation("North America", "Toronto")
+london = WeatherLocation("Europe", "London")
+dublin = WeatherLocation("Europe", "Dublin")
+brasilia = WeatherLocation("South America", "Brasila")
+quito = WeatherLocation("South America", "Quito")
tokyo.reports.append(Report(80.0))
newyork.reports.append(Report(75))
@@ -271,12 +290,14 @@ assert t.city == tokyo.city
assert t.reports[0].temperature == 80.0
north_american_cities = sess.query(WeatherLocation).filter(
- WeatherLocation.continent == 'North America')
-assert {c.city for c in north_american_cities} == {'New York', 'Toronto'}
+ WeatherLocation.continent == "North America"
+)
+assert {c.city for c in north_american_cities} == {"New York", "Toronto"}
asia_and_europe = sess.query(WeatherLocation).filter(
- WeatherLocation.continent.in_(['Europe', 'Asia']))
-assert {c.city for c in asia_and_europe} == {'Tokyo', 'London', 'Dublin'}
+ WeatherLocation.continent.in_(["Europe", "Asia"])
+)
+assert {c.city for c in asia_and_europe} == {"Tokyo", "London", "Dublin"}
# the Report class uses a simple integer primary key. So across two databases,
# a primary key will be repeated. The "identity_token" tracks in memory
@@ -284,8 +305,8 @@ assert {c.city for c in asia_and_europe} == {'Tokyo', 'London', 'Dublin'}
newyork_report = newyork.reports[0]
tokyo_report = tokyo.reports[0]
-assert inspect(newyork_report).identity_key == (Report, (1, ), "north_america")
-assert inspect(tokyo_report).identity_key == (Report, (1, ), "asia")
+assert inspect(newyork_report).identity_key == (Report, (1,), "north_america")
+assert inspect(tokyo_report).identity_key == (Report, (1,), "asia")
# the token representing the originating shard is also available directly