Come selezionare le righe da un tensore 3D in TensorFlow?

Ho un logits tensoriale con le dimensioni [batch_size, num_rows, num_coordinates] (cioè ogni logit nel batch è una matrice). Nel mio caso la dimensione del lotto è 2, ci sono 4 righe e 4 coordinate.

 logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], [11.0, 10.0, 10.0, 30.0], [12.0, 10.0, 10.0, 20.0], [13.0, 10.0, 10.0, 20.0]], [[14.0, 11.0, 21.0, 31.0], [15.0, 11.0, 11.0, 21.0], [16.0, 11.0, 11.0, 21.0], [17.0, 11.0, 11.0, 21.0]]]) 

Voglio selezionare la prima e la seconda riga del primo lotto e la seconda e la quarta riga del secondo lotto.

 indices = tf.constant([[0, 1], [1, 3]]) 

Quindi l’output desiderato sarebbe

 logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], [11.0, 10.0, 10.0, 30.0]], [[15.0, 11.0, 11.0, 21.0], [17.0, 11.0, 11.0, 21.0]]]) 

Come posso farlo usando TensorFlow? Ho provato a utilizzare tf.gather(logits, indices) ma non ha restituito ciò che mi aspettavo. Grazie!

Questo è ansible in TensorFlow, ma leggermente scomodo, perché tf.gather() attualmente funziona solo con indici unidimensionali e seleziona solo sezioni dalla dimensione 0 di un tensore. Tuttavia, è ancora ansible risolvere il problema in modo efficiente, trasformando gli argomenti in modo che possano essere passati a tf.gather() :

 logits = ... # [2 x 4 x 4] tensor indices = tf.constant([[0, 1], [1, 3]]) # Use tf.shape() to make this work with dynamic shapes. batch_size = tf.shape(logits)[0] rows_per_batch = tf.shape(logits)[1] indices_per_batch = tf.shape(indices)[1] # Offset to add to each row in indices. We use `tf.expand_dims()` to make # this broadcast appropriately. offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1) # Convert indices and logits into appropriate form for `tf.gather()`. flattened_indices = tf.reshape(indices + offset, [-1]) flattened_logits = tf.reshape(logits, tf.concat(0, [[-1], tf.shape(logits)[2:]])) selected_rows = tf.gather(flattened_logits, flattened_indices) result = tf.reshape(selected_rows, tf.concat(0, [tf.pack([batch_size, indices_per_batch]), tf.shape(logits)[2:]])) 

Si noti che, poiché utilizza tf.reshape() e non tf.transpose() , non è necessario modificare i dati (potenzialmente di grandi dimensioni) nel tensore dei logits , quindi dovrebbe essere abbastanza efficiente.

La risposta di mrry è ottima, ma penso che con la funzione tf.gather_nd il problema possa essere risolto con molte meno righe di codice (probabilmente questa funzione non era ancora disponibile al momento della scrittura di mrry):

 logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], [11.0, 10.0, 10.0, 30.0], [12.0, 10.0, 10.0, 20.0], [13.0, 10.0, 10.0, 20.0]], [[14.0, 11.0, 21.0, 31.0], [15.0, 11.0, 11.0, 21.0], [16.0, 11.0, 11.0, 21.0], [17.0, 11.0, 11.0, 21.0]]]) indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]]) result = tf.gather_nd(logits, indices) with tf.Session() as sess: print(sess.run(result)) 

Questo stamperà

 [[[ 10. 10. 20. 20.] [ 11. 10. 10. 30.]] [[ 15. 11. 11. 21.] [ 17. 11. 11. 21.]]] 

tf.gather_nd dovrebbe essere disponibile dalla v0.10. Dai un’occhiata a questo problema di github per ulteriori discussioni su questo.