Sto costruendo un perceptron multistrato semplice con TensorFlow, e ho anche bisogno di ottenere i gradienti (o il segnale di errore) della perdita agli ingressi della rete neurale.
Ecco il mio codice, che funziona:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y)) optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost) ... for i in range(epochs): .... for batch in batches: ... sess.run(optimizer, feed_dict=feed_dict) grads_wrt_input = sess.run(tf.gradients(cost, self.x), feed_dict=feed_dict)[0]
(modificato per includere il ciclo di allenamento)
Senza l’ultima riga ( grads_wrt_input...
), questo funziona molto velocemente su una macchina CUDA. Tuttavia, tf.gradients()
riduce notevolmente le prestazioni di dieci volte o più.
Ricordo che i segnali di errore ai nodes sono calcolati come valori intermedi nell’algoritmo di backpropagation, e l’ho fatto con successo usando la libreria Java DeepLearning4j. Avevo anche l’impressione che si trattasse di una leggera modifica al grafico di calcolo già realizzato optimizer
.
Come può essere reso più veloce o esiste un altro modo per calcolare i gradienti della perdita rispetto agli input?
La funzione tf.gradients()
crea un nuovo grafico di backpropagation ogni volta che viene chiamato, quindi il motivo del rallentamento è che TensorFlow deve analizzare un nuovo grafico su ogni iterazione del ciclo. (Questo può essere sorprendentemente costoso: la versione corrente di TensorFlow è ottimizzata per eseguire lo stesso grafico un gran numero di volte.)
Fortunatamente la soluzione è semplice: basta calcolare i gradienti una volta, fuori dal ciclo. Puoi ristrutturare il tuo codice come segue:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y)) optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost) grads_wrt_input_tensor = tf.gradients(cost, self.x)[0] # ... for i in range(epochs): # ... for batch in batches: # ... _, grads_wrt_input = sess.run([optimizer, grads_wrt_input_tensor], feed_dict=feed_dict)
Si noti che, per le prestazioni, ho anche combinato le due chiamate sess.run()
. Ciò garantisce che la propagazione in avanti e gran parte della backpropagation saranno riutilizzate.
Per tf.get_default_graph().finalize()
, un consiglio per trovare bug di prestazioni come questo è chiamare tf.get_default_graph().finalize()
prima di iniziare il ciclo di allenamento. Ciò solleverà un’eccezione se si aggiungono inavvertitamente nodes al grafico, il che rende più facile rintracciare la causa di questi bug.