TensorFlow: prestazioni lente quando si ottengono gradienti agli input

Sto costruendo un perceptron multistrato semplice con TensorFlow, e ho anche bisogno di ottenere i gradienti (o il segnale di errore) della perdita agli ingressi della rete neurale.

Ecco il mio codice, che funziona:

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y)) optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost) ... for i in range(epochs): .... for batch in batches: ... sess.run(optimizer, feed_dict=feed_dict) grads_wrt_input = sess.run(tf.gradients(cost, self.x), feed_dict=feed_dict)[0] 

(modificato per includere il ciclo di allenamento)

Senza l’ultima riga ( grads_wrt_input... ), questo funziona molto velocemente su una macchina CUDA. Tuttavia, tf.gradients() riduce notevolmente le prestazioni di dieci volte o più.

Ricordo che i segnali di errore ai nodes sono calcolati come valori intermedi nell’algoritmo di backpropagation, e l’ho fatto con successo usando la libreria Java DeepLearning4j. Avevo anche l’impressione che si trattasse di una leggera modifica al grafico di calcolo già realizzato optimizer .

Come può essere reso più veloce o esiste un altro modo per calcolare i gradienti della perdita rispetto agli input?

La funzione tf.gradients() crea un nuovo grafico di backpropagation ogni volta che viene chiamato, quindi il motivo del rallentamento è che TensorFlow deve analizzare un nuovo grafico su ogni iterazione del ciclo. (Questo può essere sorprendentemente costoso: la versione corrente di TensorFlow è ottimizzata per eseguire lo stesso grafico un gran numero di volte.)

Fortunatamente la soluzione è semplice: basta calcolare i gradienti una volta, fuori dal ciclo. Puoi ristrutturare il tuo codice come segue:

 cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y)) optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost) grads_wrt_input_tensor = tf.gradients(cost, self.x)[0] # ... for i in range(epochs): # ... for batch in batches: # ... _, grads_wrt_input = sess.run([optimizer, grads_wrt_input_tensor], feed_dict=feed_dict) 

Si noti che, per le prestazioni, ho anche combinato le due chiamate sess.run() . Ciò garantisce che la propagazione in avanti e gran parte della backpropagation saranno riutilizzate.


Per tf.get_default_graph().finalize() , un consiglio per trovare bug di prestazioni come questo è chiamare tf.get_default_graph().finalize() prima di iniziare il ciclo di allenamento. Ciò solleverà un’eccezione se si aggiungono inavvertitamente nodes al grafico, il che rende più facile rintracciare la causa di questi bug.