```python
if __name__ == "__main__":
```

means that the current file is being executed from a shell rather than imported as a module.

```python
tf.app.run()
```

As you can see in the file app.py:
```python
def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))
```
Let's break it down line by line:
```python
flags_passthrough = f._parse_flags(args=args)
```

This ensures that the arguments you pass on the command line are valid, e.g.

```
python my_model.py --data_dir='...' --max_iteration=10000
```

Actually, this feature is implemented on top of Python's standard argparse module.
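As a minimal sketch (the flag names simply mirror the example command line above and are not taken from any particular codebase), this is roughly how such flags would be declared in TF 1.x so that the parsing step can validate them:

```python
import tensorflow as tf

# Illustrative flag definitions matching the example command line above.
tf.app.flags.DEFINE_string('data_dir', '/tmp/data', 'Where the training data lives.')
tf.app.flags.DEFINE_integer('max_iteration', 10000, 'Number of training iterations.')

FLAGS = tf.app.flags.FLAGS

def main(argv):
    # By the time main runs, tf.app.run() has already parsed the command line,
    # so the flag values are available on the global FLAGS object.
    print(FLAGS.data_dir, FLAGS.max_iteration)

if __name__ == '__main__':
    tf.app.run()
```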
```python
main = main or sys.modules['__main__'].main
```

The first `main` on the right-hand side of `=` is the first argument of the current function `run(main=None, argv=None)`, while `sys.modules['__main__']` refers to the currently running file (e.g. `my_model.py`).
So there are two cases:

1. You don't have a `main` function in `my_model.py`. Then you have to call `tf.app.run(my_main_running_function)`.
2. You have a `main` function in `my_model.py`. (This is mostly the case.)
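A minimal sketch of both cases (function and file names are only illustrative):

```python
import tensorflow as tf

# Case 1: your entry point is not called `main`, so you pass it explicitly.
def my_main_running_function(argv):
    print('entry point called with:', argv)

if __name__ == '__main__':
    tf.app.run(my_main_running_function)

# Case 2 (the usual one): you define `main`, call tf.app.run() with no
# arguments, and it finds your function via sys.modules['__main__'].main:
#
#   def main(argv):
#       ...
#
#   if __name__ == '__main__':
#       tf.app.run()
```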
The last line:

```python
sys.exit(main(sys.argv[:1] + flags_passthrough))
```

ensures that your `main(argv)` or `my_main_running_function(argv)` function is called with the parsed arguments.
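Concretely, because of the `sys.argv[:1] + flags_passthrough` expression, your function receives the script name followed by whatever the flag parser did not consume. A small illustrative sketch:

```python
def main(argv):
    # argv[0] is the script name (sys.argv[:1]); the rest are the
    # leftover arguments returned by the flag parser (flags_passthrough).
    script_name = argv[0]
    unparsed = argv[1:]
    print('script:', script_name, 'unparsed args:', unparsed)
```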
It's just a very quick wrapper that handles flag parsing and then dispatches to your own main. See the code above.
There is nothing special in `tf.app`. It is just a generic entry point script, which "runs the program with an optional 'main' function and 'argv' list."

It has nothing to do with neural networks; it just calls the main function, passing through any arguments to it.
In simple terms, the job of `tf.app.run()` is to first set up the global flags for later usage, like:

```python
from tensorflow.python.platform import flags
f = flags.FLAGS
```

and then run your custom main function with a set of arguments.
For example, in the TensorFlow NMT codebase, the very first entry point for training/inference is this (see the code below):

```python
if __name__ == "__main__":
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
```
After parsing the arguments with `argparse`, `tf.app.run()` runs the function `main`, which is defined like this:

```python
def main(unused_argv):
  default_hparams = create_hparams(FLAGS)
  train_fn = train.train
  inference_fn = inference.inference
  run_main(FLAGS, default_hparams, train_fn, inference_fn)
```
So, after setting the flags for global use, `tf.app.run()` simply runs the `main` function that you pass to it, with `argv` as its parameter.
P.S.: As Salvador Dali's answer says, it's just good software engineering practice, I guess, although I'm not sure whether TensorFlow runs the `main` function in any more optimized way than normal CPython would.
Google code depends a lot on global flags being accessed in libraries/binaries/Python scripts, so `tf.app.run()` parses those flags to create a global state in the FLAGS variable (or something similar) and then calls the Python `main()` as it should.

If this call to `tf.app.run()` didn't exist, users might forget to do the FLAGS parsing, leaving those libraries/binaries/scripts without access to the FLAGS they need.
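As an illustration of that pattern (module and function names here are hypothetical), a library can declare and read flags it never parses itself, relying on the binary's `tf.app.run()` call to have populated the shared FLAGS object:

```python
# my_library.py (hypothetical): declares the flags it needs, never parses them.
import tensorflow as tf

tf.app.flags.DEFINE_string('data_dir', '/tmp/data', 'Input directory.')
FLAGS = tf.app.flags.FLAGS

def build_input_pipeline():
    # This only works because the binary called tf.app.run(), which parsed
    # sys.argv into the global FLAGS before main() was dispatched.
    return 'reading from ' + FLAGS.data_dir

# my_binary.py (hypothetical):
#
#   import my_library
#
#   def main(argv):
#       print(my_library.build_input_pipeline())
#
#   if __name__ == '__main__':
#       tf.app.run()
```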
2.0 Compatible Answer: If you want to use `tf.app.run()` in TensorFlow 2.0, you should use

```python
tf.compat.v1.app.run()
```

or you can use the `tf_upgrade_v2` script to convert your 1.x code to 2.0.
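For example, a minimal sketch of the compat call under TF 2.x (the body of `main` is just a placeholder), with the typical upgrade-script invocation noted in a comment:

```python
import tensorflow as tf

def main(argv):
    print('Running under TF', tf.__version__, 'with argv:', argv)

if __name__ == '__main__':
    # Same entry-point behaviour as the 1.x tf.app.run(), under the compat namespace.
    tf.compat.v1.app.run(main)

# Alternatively, the upgrade script is typically invoked from the shell as:
#   tf_upgrade_v2 --infile my_model.py --outfile my_model_v2.py
```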