Simplify aggregate code a bit.
authorAndres Freund <andres@anarazel.de>
Thu, 3 Aug 2017 22:23:40 +0000 (15:23 -0700)
committerAndres Freund <andres@anarazel.de>
Fri, 1 Sep 2017 06:22:35 +0000 (23:22 -0700)
src/backend/executor/nodeAgg.c
src/include/nodes/execnodes.h

index 1783f38f14717ffb68562cd4ef9d152ec9e56dc7..7e521459d621299aba19befa2fae38e342728bc2 100644 (file)
@@ -522,13 +522,13 @@ static void select_current_set(AggState *aggstate, int setno, bool is_hash);
 static void initialize_phase(AggState *aggstate, int newphase);
 static TupleTableSlot *fetch_input_tuple(AggState *aggstate);
 static void initialize_aggregates(AggState *aggstate,
-                     AggStatePerGroup pergroup,
-                     int numReset);
+                     AggStatePerGroup *pergroups,
+                     bool isHash, int numReset);
 static void advance_transition_function(AggState *aggstate,
                            AggStatePerTrans pertrans,
                            AggStatePerGroup pergroupstate);
-static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup,
-                  AggStatePerGroup *pergroups);
+static void advance_aggregates(AggState *aggstate, AggStatePerGroup *sort_pergroups,
+                  AggStatePerGroup *hash_pergroups);
 static void advance_combine_function(AggState *aggstate,
                         AggStatePerTrans pertrans,
                         AggStatePerGroup pergroupstate);
@@ -782,15 +782,14 @@ initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans,
  * If there are multiple grouping sets, we initialize only the first numReset
  * of them (the grouping sets are ordered so that the most specific one, which
  * is reset most often, is first). As a convenience, if numReset is 0, we
- * reinitialize all sets. numReset is -1 to initialize a hashtable entry, in
- * which case the caller must have used select_current_set appropriately.
+ * reinitialize all sets.
  *
  * When called, CurrentMemoryContext should be the per-query context.
  */
 static void
 initialize_aggregates(AggState *aggstate,
-                     AggStatePerGroup pergroup,
-                     int numReset)
+                     AggStatePerGroup *pergroups,
+                     bool isHash, int numReset)
 {
    int         transno;
    int         numGroupingSets = Max(aggstate->phase->numsets, 1);
@@ -801,30 +800,18 @@ initialize_aggregates(AggState *aggstate,
    if (numReset == 0)
        numReset = numGroupingSets;
 
-   for (transno = 0; transno < numTrans; transno++)
+   for (setno = 0; setno < numReset; setno++)
    {
-       AggStatePerTrans pertrans = &transstates[transno];
-
-       if (numReset < 0)
-       {
-           AggStatePerGroup pergroupstate;
+       AggStatePerGroup pergroup = pergroups[setno];
 
-           pergroupstate = &pergroup[transno];
+       select_current_set(aggstate, setno, isHash);
 
-           initialize_aggregate(aggstate, pertrans, pergroupstate);
-       }
-       else
+       for (transno = 0; transno < numTrans; transno++)
        {
-           for (setno = 0; setno < numReset; setno++)
-           {
-               AggStatePerGroup pergroupstate;
-
-               pergroupstate = &pergroup[transno + (setno * numTrans)];
-
-               select_current_set(aggstate, setno, false);
+           AggStatePerTrans pertrans = &transstates[transno];
+           AggStatePerGroup pergroupstate = &pergroup[transno];
 
-               initialize_aggregate(aggstate, pertrans, pergroupstate);
-           }
+           initialize_aggregate(aggstate, pertrans, pergroupstate);
        }
    }
 }
@@ -965,7 +952,7 @@ advance_transition_function(AggState *aggstate,
  * When called, CurrentMemoryContext should be the per-query context.
  */
 static void
-advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGroup *pergroups)
+advance_aggregates(AggState *aggstate, AggStatePerGroup *sort_pergroups, AggStatePerGroup *hash_pergroups)
 {
    int         transno;
    int         setno = 0;
@@ -1002,7 +989,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
        {
            /* DISTINCT and/or ORDER BY case */
            Assert(slot->tts_nvalid >= (pertrans->numInputs + inputoff));
-           Assert(!pergroups);
+           Assert(!hash_pergroups);
 
            /*
             * If the transfn is strict, we want to check for nullity before
@@ -1063,9 +1050,9 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
                fcinfo->argnull[i + 1] = slot->tts_isnull[i + inputoff];
            }
 
-           if (pergroup)
+           if (sort_pergroups)
            {
-               /* advance transition states for ordered grouping */
+               /* advance transition states for ordered grouping  */
 
                for (setno = 0; setno < numGroupingSets; setno++)
                {
@@ -1073,13 +1060,13 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
 
                    select_current_set(aggstate, setno, false);
 
-                   pergroupstate = &pergroup[transno + (setno * numTrans)];
+                   pergroupstate = &sort_pergroups[setno][transno];
 
                    advance_transition_function(aggstate, pertrans, pergroupstate);
                }
            }
 
-           if (pergroups)
+           if (hash_pergroups)
            {
                /* advance transition states for hashed grouping */
 
@@ -1089,7 +1076,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
 
                    select_current_set(aggstate, setno, true);
 
-                   pergroupstate = &pergroups[setno][transno];
+                   pergroupstate = &hash_pergroups[setno][transno];
 
                    advance_transition_function(aggstate, pertrans, pergroupstate);
                }
@@ -2061,8 +2048,8 @@ lookup_hash_entry(AggState *aggstate)
            MemoryContextAlloc(perhash->hashtable->tablecxt,
                               sizeof(AggStatePerGroupData) * aggstate->numtrans);
        /* initialize aggregates for new tuple group */
-       initialize_aggregates(aggstate, (AggStatePerGroupentry->additional,
-                             -1);
+       initialize_aggregates(aggstate, (AggStatePerGroup*) &entry->additional,
+                             true, 1);
    }
 
    return entry;
@@ -2146,7 +2133,7 @@ agg_retrieve_direct(AggState *aggstate)
    ExprContext *econtext;
    ExprContext *tmpcontext;
    AggStatePerAgg peragg;
-   AggStatePerGroup pergroup;
+   AggStatePerGroup *pergroups;
    AggStatePerGroup *hash_pergroups = NULL;
    TupleTableSlot *outerslot;
    TupleTableSlot *firstSlot;
@@ -2169,7 +2156,7 @@ agg_retrieve_direct(AggState *aggstate)
    tmpcontext = aggstate->tmpcontext;
 
    peragg = aggstate->peragg;
-   pergroup = aggstate->pergroup;
+   pergroups = aggstate->pergroups;
    firstSlot = aggstate->ss.ss_ScanTupleSlot;
 
    /*
@@ -2371,7 +2358,7 @@ agg_retrieve_direct(AggState *aggstate)
            /*
             * Initialize working state for a new input tuple group.
             */
-           initialize_aggregates(aggstate, pergroup, numReset);
+           initialize_aggregates(aggstate, pergroups, false, numReset);
 
            if (aggstate->grp_firstTuple != NULL)
            {
@@ -2408,9 +2395,9 @@ agg_retrieve_direct(AggState *aggstate)
                        hash_pergroups = NULL;
 
                    if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
-                       combine_aggregates(aggstate, pergroup);
+                       combine_aggregates(aggstate, pergroups[0]);
                    else
-                       advance_aggregates(aggstate, pergroup, hash_pergroups);
+                       advance_aggregates(aggstate, pergroups, hash_pergroups);
 
                    /* Reset per-input-tuple context after each tuple */
                    ResetExprContext(tmpcontext);
@@ -2474,7 +2461,7 @@ agg_retrieve_direct(AggState *aggstate)
 
        finalize_aggregates(aggstate,
                            peragg,
-                           pergroup + (currentSet * aggstate->numtrans));
+                           pergroups[currentSet]);
 
        /*
         * If there's no row to project right now, we must continue rather
@@ -2715,7 +2702,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
    aggstate->curpertrans = NULL;
    aggstate->input_done = false;
    aggstate->agg_done = false;
-   aggstate->pergroup = NULL;
+   aggstate->pergroups = NULL;
    aggstate->grp_firstTuple = NULL;
    aggstate->sort_in = NULL;
    aggstate->sort_out = NULL;
@@ -3019,13 +3006,17 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 
    if (node->aggstrategy != AGG_HASHED)
    {
-       AggStatePerGroup pergroup;
+       AggStatePerGroup *pergroups =
+           (AggStatePerGroup*) palloc0(sizeof(AggStatePerGroup)
+                                       * numGroupingSets);
 
-       pergroup = (AggStatePerGroup) palloc0(sizeof(AggStatePerGroupData)
-                                             * numaggs
-                                             * numGroupingSets);
+       for (i = 0; i < numGroupingSets; i++)
+       {
+           pergroups[i] = (AggStatePerGroup) palloc0(sizeof(AggStatePerGroupData)
+                                                    * numaggs);
+       }
 
-       aggstate->pergroup = pergroup;
+       aggstate->pergroups = pergroups;
    }
 
    /*
@@ -3988,8 +3979,11 @@ ExecReScanAgg(AggState *node)
        /*
         * Reset the per-group state (in particular, mark transvalues null)
         */
-       MemSet(node->pergroup, 0,
-              sizeof(AggStatePerGroupData) * node->numaggs * numGroupingSets);
+       for (setno = 0; setno < numGroupingSets; setno++)
+       {
+           MemSet(node->pergroups[setno], 0,
+                  sizeof(AggStatePerGroupData) * node->numaggs);
+       }
 
        /* reset to phase 1 */
        initialize_phase(node, 1);
index 8ae8179ee7b76acab281ee931ac5378149b91297..bc5874f1ee4251ec80e99410a152ec70b60dd2d0 100644 (file)
@@ -1823,13 +1823,15 @@ typedef struct AggState
    Tuplesortstate *sort_out;   /* input is copied here for next phase */
    TupleTableSlot *sort_slot;  /* slot for sort results */
    /* these fields are used in AGG_PLAIN and AGG_SORTED modes: */
-   AggStatePerGroup pergroup;  /* per-Aggref-per-group working state */
+   AggStatePerGroup *pergroups;    /* grouping set indexed array of per-group
+                                    * pointers */
    HeapTuple   grp_firstTuple; /* copy of first tuple of current group */
    /* these fields are used in AGG_HASHED and AGG_MIXED modes: */
    bool        table_filled;   /* hash table filled yet? */
    int         num_hashes;
    AggStatePerHash perhash;
-   AggStatePerGroup *hash_pergroup;    /* array of per-group pointers */
+   AggStatePerGroup *hash_pergroup;    /* grouping set indexed array of
+                                        * per-group pointers */
    /* support for evaluation of agg inputs */
    TupleTableSlot *evalslot;   /* slot for agg inputs */
    ProjectionInfo *evalproj;   /* projection machinery */