SQLAlchemy sum function with bounds

In sqlalchemy (postgresql DB), I would like to create a bounded sum function, for lack of a better term. The goal is to create a running total within a defined range.

Currently, I have something that works great for calculating a running total without the bounds. Something like this:

from sqlalchemy.sql import func

foos = (
    db.query(
        Foo.id,
        Foo.points,
        Foo.timestamp,
        func.sum(Foo.points).over(order_by=Foo.timestamp).label('running_total')
    )
    .filter(...)
    .all()
)

However, I would like to be able to bound this running total to always be within a specific range, let's say [-100, 100]. So we would get something like this (see running_total):

{'timestamp': 1, 'points': 75, 'running_total': 75}
{'timestamp': 2, 'points': 50, 'running_total': 100}
{'timestamp': 3, 'points': -100, 'running_total': 0}
{'timestamp': 4, 'points': -50, 'running_total': -50}
{'timestamp': 5, 'points': -75, 'running_total': -100}
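To pin down the intended semantics, here is a minimal plain-Python sketch of the calculation, outside the database. The function name `bounded_running_total` is made up for illustration; the key point is that the total is clamped after every row, so each step builds on the already-bounded value.

```python
def bounded_running_total(points, lower, upper):
    """Running total of `points`, clamped to [lower, upper] after every step."""
    totals = []
    total = 0
    for p in points:
        # clamp right away, so the next row starts from the bounded value
        total = max(lower, min(upper, total + p))
        totals.append(total)
    return totals


points = [75, 50, -100, -50, -75]
print(bounded_running_total(points, -100, 100))
# -> [75, 100, 0, -50, -100]
```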

Any ideas?

bnjmn asked Mar 09 '23 09:03


2 Answers

Note: my initial answer is wrong; see the edit below.

In raw SQL, you'd do this using the GREATEST and LEAST functions.

Something like this:

LEAST(GREATEST(SUM(myfield) OVER (window_clause), lower_bound), upper_bound)

The SQLAlchemy expression language allows one to write that almost identically:

import sqlalchemy as sa
import sqlalchemy.ext.declarative as dec
base = dec.declarative_base()

class Foo(base):
  __tablename__ = 'foo'
  id = sa.Column(sa.Integer, primary_key=True)
  points = sa.Column(sa.Integer, nullable=False)
  timestamp = sa.Column('tstamp', sa.Integer)

upper_, lower_ = 100, -100
win_expr = sa.func.sum(Foo.points).over(order_by=Foo.timestamp)
bound_expr = sa.func.least(sa.func.greatest(win_expr, lower_), upper_).label('bounded_running_total')

stmt = sa.select([Foo.id, Foo.points, Foo.timestamp, bound_expr])

str(stmt)
# prints output:
# SELECT foo.id, foo.points, foo.tstamp, least(greatest(sum(foo.points) OVER (ORDER BY foo.tstamp), :greatest_1), :least_1) AS bounded_running_total 
# FROM foo


# alternatively, using session.query you can also fetch results

from sqlalchemy.orm import sessionmaker
DB = sessionmaker()  # pass bind=<your engine> to run against a real database
db = DB()
foos_stmt = db.query(Foo.id, Foo.points, Foo.timestamp, bound_expr).filter(...)
str(foos_stmt)
# prints output:
# SELECT foo.id, foo.points, foo.tstamp, least(greatest(sum(foo.points) OVER (ORDER BY foo.tstamp), :greatest_1), :least_1) AS bounded_running_total 
# FROM foo

foos = foos_stmt.all()

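EDIT: As user @pozs points out in the comments, the above does not produce the intended results. It clamps the unbounded running sum after the fact, whereas the goal is to clamp the total at every step, so that each row builds on the already-bounded value of the previous row.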
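To see the mismatch concretely, here is a small plain-Python sketch (not part of the original answer) that applies the same "sum first, clamp afterwards" logic to the sample data from the question and compares it with the expected running totals.

```python
from itertools import accumulate

points = [75, 50, -100, -50, -75]
expected = [75, 100, 0, -50, -100]  # running_total values from the question

# what LEAST(GREATEST(SUM(...) OVER (...), -100), 100) effectively computes
plain_totals = list(accumulate(points))
clamped_afterwards = [max(-100, min(100, t)) for t in plain_totals]

# positions (0-based) where clamping afterwards disagrees with the expected output
mismatches = [
    i for i, (got, want) in enumerate(zip(clamped_afterwards, expected))
    if got != want
]

print(clamped_afterwards)  # -> [75, 100, 25, -25, -100]
print(mismatches)          # -> [2, 3]
```

The rows at timestamps 3 and 4 differ because the unbounded sum still "remembers" the 25 points that were cut off at timestamp 2.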

Two alternate approaches have been presented by @pozs. Here, I've adapted the first, recursive query approach, constructed via sqlalchemy.

import sqlalchemy as sa
import sqlalchemy.ext.declarative as dec
import sqlalchemy.orm as orm
base = dec.declarative_base()

class Foo(base):
  __tablename__ = 'foo'
  id = sa.Column(sa.Integer, primary_key=True)
  points = sa.Column(sa.Integer, nullable=False)
  timestamp = sa.Column('tstamp', sa.Integer)

upper_, lower_ = 100, -100
t = sa.select([
  # label the timestamp explicitly so the CTE column can be referenced as `tstamp`
  Foo.timestamp.label('tstamp'),
  Foo.points,
  Foo.points.label('bounded_running_sum')
]).order_by(Foo.timestamp).limit(1).cte('t', recursive=True)

t_aliased = orm.aliased(t, name='ta')

bounded_sum = t.union_all(
  sa.select([
    Foo.timestamp,
    Foo.points,
    sa.func.greatest(sa.func.least(Foo.points + t_aliased.c.bounded_running_sum, upper_), lower_)
  ])
  # step to the next row: only rows later than the one the CTE produced last
  .where(Foo.timestamp > t_aliased.c.tstamp)
  .order_by(Foo.timestamp).limit(1)
)
stmt = sa.select([bounded_sum])

# inspect the query:
from sqlalchemy.dialects import postgresql
print(stmt.compile(dialect=postgresql.dialect(), 
                   compile_kwargs={'literal_binds': True}))
# prints output along these lines: 
# WITH RECURSIVE t(tstamp, points, bounded_running_sum) AS
# ((SELECT foo.tstamp AS tstamp, foo.points, foo.points AS bounded_running_sum
# FROM foo ORDER BY foo.tstamp
# LIMIT 1) UNION ALL (SELECT foo.tstamp, foo.points, greatest(least(foo.points + ta.bounded_running_sum, 100), -100) AS greatest_1
# FROM foo, t AS ta
# WHERE foo.tstamp > ta.tstamp ORDER BY foo.tstamp
# LIMIT 1))
# SELECT t.tstamp, t.points, t.bounded_running_sum
# FROM t
# SELECT t.tstamp, t.points, t.bounded_running_sum
# FROM t

I used the SQLAlchemy documentation on recursive CTEs as a reference to construct the above; it also shows how one may use the session instead to work with recursive CTEs.

This would be the pure sqlalchemy method to generate the required results.

The 2nd approach suggested by @pozs could also be used via sqlalchemy.

The solution would have to be a variant of this section from the documentation
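As a rough sketch of what that could look like: assuming the `bounded_sum` aggregate from @pozs's answer has already been created in the database, `sa.func` can emit a call to it by name, and `.over()` turns it into a window expression just like the built-in `sum`. The Core `Table` below is only a stand-in for the `Foo` model so that the snippet is self-contained; SQLAlchemy does not know anything about `bounded_sum` itself, it simply renders the name.

```python
import sqlalchemy as sa

# stand-in for the Foo model above, defined as a Core table
metadata = sa.MetaData()
foo = sa.Table(
    "foo",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("points", sa.Integer, nullable=False),
    sa.Column("tstamp", sa.Integer),
)

lower_, upper_ = -100, 100

# sa.func.<name> renders a call to any SQL function or aggregate by name;
# bounded_sum must already exist in the database (assumption)
bounded_total = (
    sa.func.bounded_sum(foo.c.points, lower_, upper_)
    .over(order_by=foo.c.tstamp)
    .label("running_total")
)

# render the expression to inspect the generated SQL fragment
rendered = str(bounded_total)
print(rendered)
```

The labelled expression can then be passed to `db.query(...)` or `sa.select(...)` in place of the plain `func.sum(...).over(...)` column.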

Haleemur Ali answered Mar 29 '23 17:03


Unfortunately, no built-in aggregate can help you achieve your expected output with window function calls.

You could get the expected output by manually calculating the rows one-by-one with a recursive CTE:

with recursive t as (
  (select   *, points running_total
   from     foo
   order by timestamp
   limit    1)
  union all
  (select   foo.*, least(greatest(t.running_total + foo.points, -100), 100)
   from     foo, t
   where    foo.timestamp > t.timestamp
   order by foo.timestamp
   limit    1)
)
select timestamp,
       points,
       running_total
from   t;

Unfortunately, this will be very hard to achieve with SQLAlchemy.

Your other option is to write a custom aggregate for your specific needs, like:

create function bounded_add(int_state anyelement, next_value anyelement, next_min anyelement, next_max anyelement)
  returns anyelement
  immutable
  language sql
as $func$
  select least(greatest(int_state + next_value, next_min), next_max);
$func$;

create aggregate bounded_sum(next_value anyelement, next_min anyelement, next_max anyelement)
(
    sfunc    = bounded_add,
    stype    = anyelement,
    initcond = '0'
);

With this, you just need to replace your call to sum to be a call to bounded_sum:

select timestamp,
       points,
       bounded_sum(points, -100.0, 100.0) over (order by timestamp) running_total
from   foo;

This latter solution will probably scale better too.
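For intuition about how the aggregate behaves, here is a plain-Python sketch of the same fold, not something PostgreSQL runs: the state starts at `initcond`, and the state function is applied once per row in window order. `bounded_add` mirrors the SQL function above; `bounded_sum_over` is a made-up helper name.

```python
from itertools import accumulate


def bounded_add(state, next_value, next_min, next_max):
    """Python counterpart of the SQL state function bounded_add."""
    return min(max(state + next_value, next_min), next_max)


def bounded_sum_over(values, next_min, next_max, initcond=0):
    """Emulate bounded_sum(...) OVER (ORDER BY ...) for already-ordered values."""
    states = accumulate(
        values,
        lambda state, value: bounded_add(state, value, next_min, next_max),
        initial=initcond,
    )
    # the first element is the initial state itself, not a row result
    return list(states)[1:]


print(bounded_sum_over([75, 50, -100, -50, -75], -100, 100))
# -> [75, 100, 0, -50, -100]
```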

http://rextester.com/LKCUK93113

pozs answered Mar 29 '23 16:03