How does TensorFlow's GradientTape context manager know which variables to track?

So I was trying to understand how exactly TensorFlow's GradientTape works.

I can define a custom operation in a plain Python function using tensors, and it is enough that each trainable input be declared as a tf.Variable.

I am wondering what exactly TensorFlow does in the context manager's implementation.

I have looked into the source code and got as far as this: https://github.com/tensorflow/tensorflow/blob/2c2fdd3205a8d31e5f09a71ac7eb52b8c0294a60/tensorflow/python/eager/tape.py#L52

    def push_tape(tape):
      """Pushes an existing tape onto the tape stack."""
      pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape)  # pylint: disable=protected-access

which delegates to the pywrap_tensorflow bindings, and I cannot trace any further from the Python side.
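Since the real tape set lives behind those native bindings, here is a hedged, pure-Python sketch of the idea that `TFE_Py_TapeSetAdd` appears to implement: a global stack of active tapes that `push_tape` adds to. The names (`_tape_stack`, `Tape`, `pop_tape`) are my own illustration, not TensorFlow's actual internals.

    # Toy model of the tape set. In TensorFlow the stack lives in C++;
    # this sketch only shows the push/pop control flow.
    _tape_stack = []  # stand-in for the native TapeSet

    class Tape:
        """Minimal tape: just a container for recorded operations."""
        def __init__(self):
            self.recorded_ops = []

    def push_tape(tape):
        """Pushes an existing tape onto the tape stack (cf. TFE_Py_TapeSetAdd)."""
        _tape_stack.append(tape)

    def pop_tape(tape):
        """Removes the tape from the stack when its scope ends."""
        _tape_stack.remove(tape)

    t = Tape()
    push_tape(t)
    assert _tape_stack == [t]   # the tape is now "active"
    pop_tape(t)
    assert _tape_stack == []    # and inactive again

In this picture, "tracing further" stops in Python simply because the append/remove happens inside the compiled extension rather than in a Python list.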

My question is: does TensorFlow track the variables at the level of the native bindings, or is it done in Python? How does variable tracking work internally, such that I can use the simple syntactic sugar below?

    with tf.GradientTape() as tape:
        ...
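To make the question concrete, here is a hedged, pure-Python sketch of how such syntactic sugar *could* work: the tape's `__enter__` registers it on a global stack, and every read of a variable notifies all currently active tapes. All class and variable names here (`ToyVariable`, `ToyGradientTape`, `_active_tapes`) are invented for illustration; they are not TensorFlow's real types.

    # Toy stand-in for the global tape set managed by the C++ layer.
    _active_tapes = []

    class ToyVariable:
        """Toy analogue of tf.Variable: reading its value notifies active tapes."""
        def __init__(self, value):
            self.value = value

        def read(self):
            for tape in _active_tapes:
                tape.watched.add(self)  # the tape now "tracks" this variable
            return self.value

    class ToyGradientTape:
        def __init__(self):
            self.watched = set()

        def __enter__(self):
            _active_tapes.append(self)   # analogue of TFE_Py_TapeSetAdd
            return self

        def __exit__(self, *exc):
            _active_tapes.remove(self)   # analogue of tape removal on exit

    v = ToyVariable(3.0)
    with ToyGradientTape() as tape:
        y = v.read() * 2.0
    assert v in tape.watched     # the variable was tracked inside the scope
    assert _active_tapes == []   # the tape is gone once the block exits

Under this sketch, nothing special happens to the variable itself; tracking is a side effect of operations checking the active-tape stack, which would explain why merely declaring something as a tf.Variable suffices.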