
How to find built-in function source code in pytorch

Tags:

pytorch

I am doing research on batch normalization and need to make some modifications to the PyTorch BN code. I dug into the PyTorch source and got stuck at torch.nn.functional.batch_norm, which calls torch.batch_norm.

The problem is that I cannot find a definition of torch.batch_norm anywhere in the torch library. Is there any way I can find the source code of this built-in function and re-implement it? Thanks!

StoneFree asked Mar 03 '23 08:03



1 Answer

It's there, but it's not defined in Python. Functions like torch.batch_norm are implemented in C++ under the aten/ directories of the PyTorch repository.
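A quick way to tell whether a function has Python source at all is to ask the standard inspect module. The sketch below uses only the standard library; has_python_source is a hypothetical helper name, and len and json.dumps stand in for a compiled built-in and a pure-Python function. With PyTorch installed, the same call could be tried on torch.nn.functional.batch_norm and torch.batch_norm, assuming the latter is exposed as a compiled built-in.

```python
import inspect
import json


def has_python_source(obj):
    """Return True if inspect can retrieve Python source code for obj.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    try:
        inspect.getsource(obj)
        return True
    except (TypeError, OSError):
        # TypeError: obj is a compiled built-in, so there is no Python source.
        # OSError: obj is a Python object but its source file was not found.
        return False


# json.dumps is written in Python; len is implemented in C.
pure_python_result = has_python_source(json.dumps)
builtin_result = has_python_source(len)
print(pure_python_result, builtin_result)
```

When the helper returns False for a function, searching the Python files of the package will not find its body; the implementation has to be looked up in the project's compiled sources instead.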

For CPU, one of the implementations (which one is used depends on whether or not the input is contiguous) is here: https://github.com/pytorch/pytorch/blob/420b37f3c67950ed93cd8aa7a12e673fcfc5567b/aten/src/ATen/native/Normalization.cpp#L61-L126

For CUDA, the implementation is here: https://github.com/pytorch/pytorch/blob/7aae51cdedcbf0df5a7a8bf50a947237ac4b3ee8/aten/src/ATen/native/cudnn/BatchNorm.cpp#L52-L143
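If the goal is to modify batch normalization for research, it may be easier to start from the arithmetic than from the C++ kernels. The following is a minimal plain-Python sketch of training-mode batch normalization for a batch of shape (N, C) stored as nested lists. It is not the ATen implementation: batch_norm_train is a hypothetical name, and running statistics, evaluation mode, and higher-dimensional inputs are left out. It assumes the usual formulation, in which each channel is normalized with the batch mean and the biased (divide-by-N) batch variance.

```python
import math


def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm over a batch x of shape (N, C).

    x is a list of N rows, each a list of C floats. gamma and beta are
    per-channel scale and shift lists of length C. Returns the normalized
    batch together with the per-channel batch means and biased variances.
    """
    n = len(x)
    channels = len(x[0])
    # Per-channel statistics computed across the batch dimension.
    means = [sum(row[c] for row in x) / n for c in range(channels)]
    variances = [
        sum((row[c] - means[c]) ** 2 for row in x) / n
        for c in range(channels)
    ]
    # Normalize, then apply the learnable scale (gamma) and shift (beta).
    out = [
        [
            gamma[c] * (row[c] - means[c]) / math.sqrt(variances[c] + eps)
            + beta[c]
            for c in range(channels)
        ]
        for row in x
    ]
    return out, means, variances


batch = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
out, means, variances = batch_norm_train(batch, gamma=[1.0, 1.0], beta=[0.0, 0.0])
print(means, variances)
```

Once a variant behaves as intended on small inputs like this, the same steps can be rewritten with tensor operations so that autograd handles the backward pass.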

JoshVarty answered Mar 11 '23 10:03
