Vectorize Your Sampling from a Categorical Distribution Using Gumbel-max! Use pandas.DataFrame.shift() more!

Behind this disaster of a title lies the secret to quickly sampling from a categorical distribution in Python!
data science
code
Author

Daniel Claborne

Published

November 28, 2022

Honestly, what a disaster of a title. I don’t know if either part in isolation would be more likely to get someone to read this, but I just wanted to make a post. Maybe I should have click-baited with something completely unrelated, oh well.

I’m currently taking Machine Learning for Trading from the Georgia Tech online CS masters program as part of my plan to motivate myself to self-study by paying them money. While much of the ML content in the class is review for me, it has been fun to learn things about trading, as well as do some numpy/pandas exercises.\(^{[1]}\)

The class is heavy on vectorizing your code so it’s fast (speed is money in trading), as well as on working with time series. I’ll go over two things I’ve found neat/useful. One is vectorizing sampling from a categorical distribution where the rows are logits. The other is using the .shift() method of pandas DataFrames and Series.

Vectorized sampling from a categorical distribution

Okay, so the setup is you have an array that looks like this:

import numpy as np
np.random.seed(23496)

# The ML people got to me and I now call everything that gets squashed to a probability distribution 'logits'
logits = np.random.uniform(size = (1000, 10))
logits = logits/logits.sum(axis = 1)[:,None]

logits[:5,:]
array([[0.13460263, 0.12458665, 0.05453746, 0.11991544, 0.07351353,
        0.11034637, 0.07374194, 0.11460002, 0.11551094, 0.07864502],
       [0.18602867, 0.09960763, 0.02422872, 0.10095124, 0.02961313,
        0.04475981, 0.08855924, 0.11246979, 0.16960986, 0.14417191],
       [0.0142491 , 0.14630917, 0.11735343, 0.12211442, 0.11230253,
        0.12474719, 0.13253043, 0.01106296, 0.08627144, 0.13305933],
       [0.09227899, 0.15207502, 0.07677232, 0.16330634, 0.11855988,
        0.08710454, 0.05458428, 0.18425363, 0.0224089 , 0.04865609],
       [0.01826615, 0.1956786 , 0.03484824, 0.12495028, 0.11824123,
        0.01893324, 0.17954348, 0.15826364, 0.1351583 , 0.01611684]])

Each row can be seen as the bin probabilities of a categorical distribution. Now suppose we want to sample from each of those distributions. One way you might do it is by leveraging apply_along_axis:

samples = np.apply_along_axis(
    lambda x: np.random.choice(range(len(x)), p=x), 
    1, 
    logits
)

samples[:10], samples.shape
(array([3, 1, 6, 7, 8, 9, 2, 4, 5, 3]), (1000,))

Hm, okay this works, but it is basically running a for loop over the rows of the array. Generally, apply_along_axis is not what you want to be doing if speed is a concern.
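
For intuition, that apply_along_axis call is doing roughly the same work as an explicit Python-level loop over the rows. Here's a sketch of that mental model (not numpy's actual implementation):

samples = np.empty(logits.shape[0], dtype=int)
for i, row in enumerate(logits):
    # one np.random.choice call per row -- this per-row Python overhead is
    # what makes the apply_along_axis version slow
    samples[i] = np.random.choice(len(row), p=row)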

So how do we vectorize this? The answer I provide here takes advantage of the Gumbel-max trick for sampling from a categorical distribution. Essentially, given probabilities \(\pi_i,\ i \in \{0, 1, \dots, K\},\ \sum_i \pi_i = 1\), if you add Gumbel-distributed noise to the log of the probabilities and then take the argmax, it is equivalent to sampling from that categorical distribution.
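
Stated a bit more formally: if \(g_0, \dots, g_K\) are independent draws from a standard Gumbel distribution, then

\[
P\left(\arg\max_i \,(\log \pi_i + g_i) = k\right) = \pi_k,
\]

which is exactly the categorical distribution we wanted to sample from.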

Again: take the log of the probabilities, add Gumbel noise, then take the argmax of the result.

samples = (
    np.log(logits) + \
    np.random.gumbel(size = logits.shape)
    ).argmax(axis = 1)  

samples[:10], samples.shape
(array([0, 0, 5, 2, 2, 5, 6, 9, 7, 9]), (1000,))

Let’s test whether this is actually faster:

%%timeit
(np.log(logits) + np.random.gumbel(size = logits.shape)).argmax(axis = 1)  
334 μs ± 671 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%%timeit
np.apply_along_axis(lambda x: np.random.choice(range(len(x)), p=x), 1, logits)  
17 ms ± 483 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Yea, so roughly 50x faster with vectorization. We should probably also check that it produces a similar distribution across many samples (and also put a plot in here to break up the wall of text). I’ll verify by doing barplots of the distribution of 1000 values drawn from 4 of the probability distributions. Brb, going down the stackoverflow wormhole because no one knows how to make plots, no one.

Ok I’m back, here is a way to make grouped barplots with seaborn:

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# do 1000 draws from all distributions using gumbel-max
gumbel_draws = []

for i in range(1000):
    samples = (
        np.log(logits) + np.random.gumbel(size = logits.shape)
    ).argmax(axis = 1) 

    gumbel_draws.append(samples) 

gumbel_arr = np.array(gumbel_draws)

# ...and 1000 using apply_along_axis + np.random.choice
apply_func_draws = []

for i in range(1000):
    samples = np.apply_along_axis(
        lambda x: np.random.choice(range(len(x)), p=x), 
        1, 
        logits
    )
    apply_func_draws.append(samples)

apply_func_arr = np.array(apply_func_draws)

In the above, if you ran the two for loops separately, you would get a better sense of how much faster the vectorized code is. Now we’ll munge these arrays into dataframes, melt them to long format, and feed them to seaborn.

gumbel_df = pd.DataFrame(gumbel_arr[:,:4])
apply_func_df = pd.DataFrame(apply_func_arr[:,:4])

gumbel_df = pd.melt(gumbel_df, var_name = "distribution")
apply_func_df = pd.melt(apply_func_df, var_name = "distribution")

fig, axs = plt.subplots(1, 2, figsize = (14, 8))

p = sns.countplot(data = gumbel_df, x="distribution", hue="value", ax = axs[0])
p.legend(title='Category', bbox_to_anchor=(1, 1), loc='upper left')
axs[0].set_title("Using Gumbel-max")

p = sns.countplot(data = apply_func_df, x="distribution", hue="value", ax = axs[1])
p.legend(title='Category', bbox_to_anchor=(1, 1), loc='upper left')
axs[1].set_title("Using apply_along_axis + np.random.choice")

fig.tight_layout()

The distribution of drawn values should be roughly the same.

Eyeballing these, they look similar enough that I feel confident I’ve not messed up somewhere.
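
If you want something more quantitative than eyeballing barplots, a quick sanity check (just a sketch, reusing the gumbel_arr built above) is to compare the empirical frequencies from one of the rows against its true probabilities:

# empirical frequencies of the 1000 Gumbel-max draws from the first distribution
empirical = np.bincount(gumbel_arr[:, 0], minlength=logits.shape[1]) / gumbel_arr.shape[0]

# compare to the true probabilities in the first row of logits
np.round(empirical, 3), np.round(logits[0], 3)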

Finally, of note is that a modification of this trick is used a lot in training deep learning models that want to sample from a categorical distribution (Wav2vec (Baevski et al. 2020) and DALL-E (Ramesh et al. 2021) use it). I’ll go over it in another post, but tl;dr: the network learns the probabilities and the argmax is replaced with a softmax to allow backpropagation.
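
I’ll save the details for that post, but here is a rough numpy sketch of the idea (the temperature tau is just an illustrative value; in practice this lives inside an autodiff framework and operates on learned logits):

# soft version of the Gumbel-max sample: a differentiable, approximately
# one-hot vector per row instead of a hard index
tau = 0.5  # temperature; smaller tau -> closer to a hard one-hot sample
perturbed = np.log(logits) + np.random.gumbel(size=logits.shape)
soft_samples = np.exp(perturbed / tau)
soft_samples = soft_samples / soft_samples.sum(axis=1, keepdims=True)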

pandas.DataFrame.shift()

You could probably just go read the docs on this method, but I’ll try to explain why it’s useful. In the class we often had to compute lagged differences or ratios for trading data indexed by date. To start, I’ll give some solutions that don’t work or are bad for some reason, but might seem like reasonable first attempts. Let’s make a dataframe with a date index to play around with:

import pandas as pd

mydf = pd.DataFrame(
    {"col1":np.random.uniform(size=100), 
    "col2":np.random.uniform(size=100)}, 
    index = pd.date_range(start = "11/29/2022", periods=100)
)

mydf.head()
col1 col2
2022-11-29 0.478292 0.121631
2022-11-30 0.664101 0.781843
2022-12-01 0.245395 0.005426
2022-12-02 0.726935 0.532795
2022-12-03 0.658744 0.970972

Now, suppose we want to compute the lag 1 difference. Specifically, make a new series \(s\) where \(s[t] = \text{col1}[t] - \text{col2}[t-1]\) for \(t > 0\), and \(s[0] = \text{NaN}\). Naive first attempt:

mydf["col1"][1:] - mydf["col2"][:-1]
2022-11-29         NaN
2022-11-30   -0.117742
2022-12-01    0.239969
2022-12-02    0.194140
2022-12-03   -0.312228
                ...   
2023-03-04   -0.214499
2023-03-05    0.361584
2023-03-06   -0.199390
2023-03-07    0.272564
2023-03-08         NaN
Freq: D, Length: 100, dtype: float64

Uh, so what happened here? Well, pandas does subtraction by aligning on the index, like a join, so we just subtracted values at matching dates; the first date is missing from the sliced col1 and the last date from the sliced col2, so we get NaN at those dates. Clearly this is not what we want.

Another option is to convert to numpy, which is essentially just a way to move to element-wise (positional) subtraction:

lag1_arr = np.array(mydf["col1"][1:]) - np.array(mydf["col2"][:-1])
lag1_arr[:5], lag1_arr.shape
(array([ 0.54247038, -0.53644839,  0.72150904,  0.12594842, -0.225511  ]),
 (99,))

Of course, this is not the same length as the series, so we have to do some finagling to get it to look right.

# prepend a NaN
lag1_arr = np.insert(lag1_arr, 0, np.nan)
lag1_arr[:5], lag1_arr.shape
(array([        nan,  0.54247038, -0.53644839,  0.72150904,  0.12594842]),
 (100,))

Ok, it’s the same length and has the right values, so we can put it back in the dataframe as a column or create a new series (and add the index back).

# make a new series
lag1_series = pd.Series(lag1_arr, index=mydf.index)

# or make a new column
# mydf["col3"] = lag1_arr

Alright, but this looks kinda ugly. We can do the same thing much more cleanly with the .shift() method of pandas DataFrames and Series. .shift(N) does what it sounds like: it shifts the values N places forward (or backward for negative N), but keeps the index of the series/dataframe fixed.

mydf["col1"].shift(3)
2022-11-29         NaN
2022-11-30         NaN
2022-12-01         NaN
2022-12-02    0.478292
2022-12-03    0.664101
                ...   
2023-03-04    0.807256
2023-03-05    0.992331
2023-03-06    0.974969
2023-03-07    0.339215
2023-03-08    0.625530
Freq: D, Name: col1, Length: 100, dtype: float64

With this we can easily compute the lag 1 difference, keeping the indices and such.

# difference
mydf["col1"] - mydf["col2"].shift(1)

# lag 3 ratio
mydf["col1"]/mydf["col2"].shift(3)
2022-11-29         NaN
2022-11-30         NaN
2022-12-01         NaN
2022-12-02    5.976563
2022-12-03    0.842553
                ...   
2023-03-04    0.478158
2023-03-05    0.635103
2023-03-06    0.694335
2023-03-07    0.804248
2023-03-08    3.747774
Freq: D, Length: 100, dtype: float64

This lag-N difference or ratio is extremely common and honestly I can’t believe I hadn’t been using .shift() more.
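
For example, a daily return is just a lag-1 ratio on a single column; here is a quick sketch using the toy mydf from above:

# return of col1 relative to the previous day; the first entry is NaN
# since there is no prior day to compare against
daily_return = mydf["col1"]/mydf["col1"].shift(1) - 1
daily_return.head()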


\(^{[1]}\)I am not affiliated with GA-Tech beyond the new washing machine and jacuzzi they gave me to advertise their OMSCS program

References

Baevski, Alexei, Henry Zhou, Abdelrahman Mohamed, and Michael Auli. 2020. “Wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations.” arXiv:2006.11477 [cs, eess], October. http://arxiv.org/abs/2006.11477.
Ramesh, Aditya, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, and Ilya Sutskever. 2021. “Zero-Shot Text-to-Image Generation.” arXiv. https://doi.org/10.48550/arXiv.2102.12092.