Efficient rolling sum (window aggregate) in SAS

Tags: sql, sas

I have two tables:

  • tb_payments: contract_id, payment_date, payment_value
  • tb_reference: contract_id, reference_date

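For each (contract_id, reference_date) in tb_reference, I want to create a column sum_payments holding the 90-day rolling sum of payment_value from tb_payments, i.e. the sum over payments of the same contract with reference_date - 90 < payment_date <= reference_date. I can accomplish this (very inefficiently) with the query below: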

%let window=90;
proc sql;
    create index contract_id on tb_payments;
quit;
proc sql;
    create table tb_rolling as
    select a.contract_id,
           a.reference_date,
           (select sum(b.payment_value)
            from tb_payments as b
            where a.contract_id = b.contract_id
                  and a.reference_date - &window. < b.payment_date
                  and b.payment_date <= a.reference_date
           ) as sum_payments
    from tb_reference as a;
quit;

How can I rewrite this to reduce the time complexity, using PROC SQL or a SAS DATA step?

Edit with more info:

  • I chose 90 days as the window arbitrarily, but I will perform calculations for several windows. A solution that can compute several windows at the same time would be ideal.
  • Both tables can have 10+ million rows, and the data are completely arbitrary. My SAS server is quite powerful, though.
  • contract_id values can be repeated in both tables.
  • The pairs (contract_id, reference_date) and (contract_id, payment_date) are unique.

Edit with sample data:

%let seed=1111;
data tb_reference (drop=i);
    call streaminit(&seed.);
    do i = 1 to 10000;
        contract_id = round(rand('UNIFORM')*1000000,1);
        output;
    end;
run;
proc surveyselect data=tb_reference out=tb_payments n=5000 seed=&seed.; run;
data tb_reference(drop=i);
    format reference_date date9.;
    call streaminit(&seed.);
    set tb_reference;
    do i = 1 to 1+round(rand('UNIFORM')*4,1);
        reference_date = '01jan2016'd + round(rand('UNIFORM')*1000,1);
        output;
    end;
run;
proc sort data=tb_reference nodupkey; by contract_id reference_date; run;
data tb_payments(drop=i);
    format payment_date date9. payment_value comma20.2;
    call streaminit(&seed.);
    set tb_payments;
    do i = 1 to 1+round(rand('UNIFORM')*20,1);
        payment_date = '01jan2015'd + round(rand('UNIFORM')*1365,1);
        payment_value = round(rand('UNIFORM')*3333,0.01);
        output;
    end;
run;
proc sort data=tb_payments nodupkey; by contract_id payment_date; run;

Update: I compared my naive solution to both proposals from Quentin and Tom.

  • The merge method is quite fast and achieved over 10x speedup for n=10000. It is also very powerful, as beautifully demonstrated by Tom in his answer.
  • Hash tables are insanely fast and achieved over 500x speedup. Because my datasets are large, this is the way to go, but there's a catch: they need to fit in RAM.

If anyone needs the full testing code, feel free to send me a message.

Will Razen asked Jan 27 '23 12:01



1 Answer

Here's an example of a hash approach. Since your data are already sorted, I don't think there is much benefit to the hash approach over Tom's merge approach.

The general idea is to read all of the payment data into a hash table (you may run out of memory if your real data is too big), then read through the data set of reference dates. For each reference date, look up all of the payments for that contract_id and iterate through them, testing whether the payment date falls within the 90 days ending on the reference_date (0 <= reference_date - payment_date < 90) and, if so, adding payment_value to sum_payments.

This should be noticeably faster than the SQL approach in your question, but it could lose to the MERGE approach. If the data were not sorted in advance, it might beat the time needed to sort both big datasets and then merge them. It would also handle multiple payments on the same date.

data want (drop=rc payment_date payment_value);
  *initialize the host variables for the hash data items ;
  call missing(payment_date,payment_value) ;

  *Load a hash table with all of the payment data, once, on the first iteration ;
  *multidata: yes keeps every payment when a contract_id has several ;
  if _n_=1 then do ;
    declare hash h(dataset:"tb_payments", multidata: "yes");
    h.defineKey("contract_id");
    h.defineData("payment_date","payment_value");
    h.defineDone() ;
  end ;

  *read in the reference dates ;
  set tb_reference (keep=contract_id reference_date) ;

  *for each reference date, look up all the payments for that contract_id ;
  *and iterate through them.  If the payment date is within the 90 days ending ;
  *on the reference date, increment sum_payments ;

  sum_payments=0 ;
  rc=h.find();
  do while (rc = 0); *found a record;
    if 0<=(reference_date-payment_date)<90 then sum_payments = sum_payments + payment_value ;
    rc=h.find_next();
  end;
run ;
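If the payment table does not fit in RAM, and to cover several windows in one pass, here is a minimal sketch of a different technique (it is not taken from either answer): a DOW loop that interleaves the two tables with SET/BY and buffers only the current contract's payments in temporary arrays. It assumes both tables are sorted by contract_id (the sample-data code sorts them with PROC SORT) and that payment_value is non-negative. The array bound of 10000 payments per contract, the window lengths 30/60/90, and the output names sum_payments_30/60/90 are illustrative assumptions.

```sas
/* Sketch: assumes tb_payments and tb_reference are both sorted by contract_id. */
data tb_rolling (keep=contract_id reference_date
                      sum_payments_30 sum_payments_60 sum_payments_90);
  /* Buffers for one contract's payments; 10000 is an assumed upper bound. */
  array p_date  [10000] _temporary_;
  array p_value [10000] _temporary_;
  /* Illustrative window lengths in days and the matching output columns. */
  array win     [3] _temporary_ (30 60 90);
  array sum_pay [3] sum_payments_30 sum_payments_60 sum_payments_90;

  n_pay = 0;

  /* DOW loop: each DATA step iteration consumes one contract_id BY group.  */
  /* Within a BY group, interleaving reads the tb_payments rows first,      */
  /* because tb_payments is listed first on the SET statement.              */
  do until (last.contract_id);
    set tb_payments  (in=in_pay keep=contract_id payment_date payment_value)
        tb_reference (in=in_ref keep=contract_id reference_date);
    by contract_id;

    if in_pay then do;
      /* Buffer this payment for the current contract. */
      n_pay = n_pay + 1;
      if n_pay > dim(p_date) then do;
        put "ERROR: payment buffer too small for " contract_id=;
        stop;
      end;
      p_date[n_pay]  = payment_date;
      p_value[n_pay] = payment_value;
    end;
    else if in_ref then do;
      /* Same condition as the SQL: reference_date - w < payment_date <= reference_date */
      do w = 1 to dim(win);
        sum_pay[w] = 0;
        do j = 1 to n_pay;
          if 0 <= reference_date - p_date[j] < win[w] then
            sum_pay[w] = sum_pay[w] + p_value[j];
        end;
      end;
      output;
    end;
  end;
run;
```

Every tb_reference row produces exactly one output row (with zeros when the contract has no payments), and only one contract's payments are held in memory at a time. The inner loop rescans that contract's payments for each reference date and window, which is cheap when contracts have few payments (at most 21 in the sample data); with very many payments per contract, advancing a pointer through the date-sorted payments would avoid the rescans.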
Quentin answered Feb 02 '23 14:02
