python - started - tensorflow tutorial pdf




Tensorflow: cómo obtener todas las variables de rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell (2)

La forma más fácil de resolver su problema es usar el alcance variable. Los nombres de las variables dentro de un ámbito serán prefijados con su nombre. Aquí hay un pequeño fragmento de código:

cell = rnn_cell.BasicLSTMCell(num_nodes)

with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  # Retrieve just the LSTM variables.
  lstm_variables = [v for v in tf.all_variables()
                    if v.name.startswith(vs.name)]

# [..]
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)

Funcionaría de la misma manera con MultiRNNCell .

EDITAR: ha cambiado tf.trainable_variables a tf.all_variables()

Tengo una configuración donde necesito inicializar un LSTM después de la inicialización principal que utiliza tf.initialize_all_variables() . Es decir, quiero llamar a tf.initialize_variables([var_list])

¿Hay alguna forma de recopilar todas las variables internas para ambos:

  • rnn_cell.BasicLSTM
  • rnn_cell.MultiRNNCell

¿Para que pueda inicializar SOLO estos parámetros?

La razón principal por la que quiero esto es porque no quiero reinicializar algunos valores entrenados de antes.


También puedes usar tf.get_collection() :

cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)

(Copiado en parte de la respuesta de Rafal)

Tenga en cuenta que la última línea es equivalente a la lista de comprensión en el código de Rafal.

Básicamente, tensorflow almacena una colección global de variables, que puede ser recuperada por tf.all_variables() o tf.get_collection(tf.GraphKeys.VARIABLES) . Si especifica el scope (nombre del ámbito) en la función tf.get_collection() , solo obtendrá tensores (variables en este caso) en la colección cuyos ámbitos se encuentran dentro del ámbito especificado.

EDITAR: También puede usar tf.GraphKeys.TRAINABLE_VARIABLES para obtener solo variables entrenables. Pero como vanilla BasicLSTMCell no inicializa ninguna variable no entrenable, ambas serán funcionalmente equivalentes. Para obtener una lista completa de las colecciones de gráficos predeterminadas, consulte this .





tensorflow