TensorFlow, perché ci sono 3 file dopo aver salvato il modello?

Dopo aver letto i documenti , ho salvato un modello in TensorFlow , ecco il mio codice demo:

 # Create some variables. v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") ... # Add an op to initialize the variables. init_op = tf.global_variables_initializer() # Add ops to save and restore all the variables. saver = tf.train.Saver() # Later, launch the model, initialize the variables, do some work, save the # variables to disk. with tf.Session() as sess: sess.run(init_op) # Do some work with the model. .. # Save the variables to disk. save_path = saver.save(sess, "/tmp/model.ckpt") print("Model saved in file: %s" % save_path) 

ma dopo, ho trovato che ci sono 3 file

 model.ckpt.data-00000-of-00001 model.ckpt.index model.ckpt.meta 

E non riesco a ripristinare il modello ripristinando il file model.ckpt , poiché non esiste tale file. Ecco il mio codice

 with tf.Session() as sess: # Restore variables from disk. saver.restore(sess, "/tmp/model.ckpt") 

Quindi, perché ci sono 3 file?

Prova questo:

 with tf.Session() as sess: saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta') saver.restore(sess, "/tmp/model.ckpt") 

Il metodo di salvataggio TensorFlow consente di salvare tre tipi di file perché memorizza la struttura del grafico separatamente dai valori delle variabili . Il file .meta descrive la struttura del grafico salvata, quindi è necessario importarlo prima di ripristinare il checkpoint (altrimenti non sa a quali variabili corrispondono i valori di checkpoint salvati).

In alternativa, puoi fare questo:

 # Recreate the EXACT SAME variables v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") ... # Now load the checkpoint variable values with tf.Session() as sess: saver = tf.train.Saver() saver.restore(sess, "/tmp/model.ckpt") 

Anche se non esiste un file denominato model.ckpt , si fa comunque riferimento al punto di controllo salvato con quel nome durante il ripristino. Dal codice sorgente saver.py :

Gli utenti devono solo interagire con il prefisso specificato dall’utente … invece di qualsiasi nome di percorso fisico.

  • meta file : descrive la struttura del grafo salvata, include GraphDef, SaverDef e così via; quindi applica tf.train.import_meta_graph('/tmp/model.ckpt.meta') , ripristinerà Saver e Graph .

  • indice : è una tabella immutabile stringa-stringa (tensorflow :: table :: Table). Ogni chiave è un nome di un tensore e il suo valore è un pacchetto di serie serializzato. Ogni BundleEntryProto descrive i metadati di un tensore: quale dei file “dati” contiene il contenuto di un tensore, l’offset in quel file, il checksum, alcuni dati ausiliari, ecc.

  • file di dati : è raccolta TensorBundle, salva i valori di tutte le variabili.

Sto ripristinando le formazioni vocali addestrate dal tutorial di tensorflow di Word2Vec.

Nel caso in cui siano stati creati più checkpoint:

ad esempio i file creati assomigliano a questo

model.ckpt-55695.data-00000-of-00001

model.ckpt-55695.index

model.ckpt-55695.meta

prova questo

 def restore_session(self, session): saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta') saver.restore(session, './tmp/model.ckpt-55695') 

quando si chiama restore_session ():

 def test_word2vec(): opts = Options() with tf.Graph().as_default(), tf.Session() as session: with tf.device("/cpu:0"): model = Word2Vec(opts, session) model.restore_session(session) model.get_embedding("assistance") 

Ad esempio, se hai addestrato una CNN con dropout, puoi farlo:

 def predict(image, model_name): """ image -> single image, (width, height, channels) model_name -> model file that was saved without any extensions """ with tf.Session() as sess: saver = tf.train.import_meta_graph('./' + model_name + '.meta') saver.restore(sess, './' + model_name) # Substitute 'logits' with your model prediction = tf.argmax(logits, 1) # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})