 

Loading lightgbm model and using predict with parallel for loop freezes (Python)

I need to use my model to make predictions in batches and in parallel in Python. If I load the model, create the data frames in a regular for loop, and use the predict function, everything works with no issues. If I instead create disjoint data frames in parallel using multiprocessing and then call the predict function, the loop freezes indefinitely. Why does this behavior occur?

Here is a snippet of my code:

import pickle
import random

import pandas as pd

# flat_map and get_mapping are small helper functions defined elsewhere in my code.

with open('models/model_test.pkl', 'rb') as fin:
    pkl_bst = pickle.load(fin)

def predict_generator(X):

    df = X

    print(df.head())
    df = (df.groupby(['user_id']).recommender_items.apply(flat_map)
          .reset_index().drop('level_1', axis=1))
    df.columns = ['user_id', 'product_id']

    print('Merge Data')
    user_lookup = pd.read_csv('data/user_lookup.csv')
    product_lookup = pd.read_csv('data/product_lookup.csv')
    product_map = dict(zip(product_lookup.product_id, product_lookup.name))

    print(user_lookup.head())

    df = pd.merge(df, user_lookup, on=['user_id'])
    df = pd.merge(df, product_lookup, on=['product_id'])
    df = df.sort_values(['user_id', 'product_id'])

    users = df.user_id.values
    items = df.product_id.values
    df.drop(['user_id', 'product_id'], axis=1, inplace=True)

    print('Prediction Step')

    prediction = pkl_bst.predict(df, num_iteration=pkl_bst.best_iteration)
    print('Prediction Complete')

    validation = pd.DataFrame(zip(users, items, prediction),
                              columns=['user', 'item', 'prediction'])
    validation['name'] = (validation.item
                          .apply(lambda x: get_mapping(x, product_map)))
    validation = pd.DataFrame(zip(validation.user,
                              zip(validation.name,
                                  validation.prediction)),
                              columns=['user', 'prediction'])
    print(validation.head())

    def get_items(x):

        sorted_list = sorted(list(x), key=lambda i: i[1], reverse=True)[:20]
        sorted_list = random.sample(sorted_list, 10)
        return [k for k, _ in sorted_list]

    relevance = validation.groupby('user').prediction.apply(get_items)
    return relevance.reset_index()

This works but is very slow:

results = []
for d in df_list_sub:
    r = predict_generator(d)
    results.append(r)

This breaks:

from multiprocessing import Pool
import tqdm

pool = Pool(processes=8)
results = []
for x in tqdm.tqdm(pool.imap_unordered(predict_generator, df_list_sub), total=len(df_list_sub)):
    results.append(x)
pool.close()
pool.join()

I would be very thankful if someone could help me.

asked Oct 29 '22 by RDizzl3

1 Answer

Stumbled onto this myself as well. It happens because LightGBM only allows the predict function to be called from a single process. The developers added this restriction deliberately: prediction already uses all available CPUs, so calling predict from multiple processes makes little sense and would most likely hurt performance rather than improve it. More information can be found in this GitHub issue.
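A practical way around it (a minimal sketch, not the library's prescribed approach) is to parallelise only the pandas feature engineering and keep the single pkl_bst.predict call in the parent process, where LightGBM can use all cores on its own. The build_features helper below is hypothetical and stands for everything predict_generator does before the prediction step; df_list_sub is the list of disjoint data frames from the question:

import pickle
from multiprocessing import Pool

import pandas as pd

with open('models/model_test.pkl', 'rb') as fin:
    pkl_bst = pickle.load(fin)

def build_features(X):
    # Heavy pandas work (flat_map, merges with the lookup tables, sorting)
    # taken from predict_generator, minus the call to pkl_bst.predict.
    # This part is safe to run in worker processes.
    df = X.copy()
    # ... same feature engineering as in the question ...
    return df

if __name__ == '__main__':
    with Pool(processes=8) as pool:
        feature_frames = pool.map(build_features, df_list_sub)

    # Predict sequentially in the parent process; LightGBM already
    # parallelises this step across all available CPU cores.
    results = [pkl_bst.predict(df, num_iteration=pkl_bst.best_iteration)
               for df in feature_frames]

If prediction really has to happen inside the workers, restricting LightGBM's own threading in each worker (its num_threads parameter) is the usual mitigation, but a single predict call in the parent is simpler and typically just as fast.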

answered Jan 02 '23 by Daan Klijn