[Python] Tensorflow: मॉडल को कैसे सहेज / पुनर्स्थापित करना है?


Answers

(और बाद में) टेंसरफ्लो संस्करण 0.11.0 आरसी 1, आप https://www.tensorflow.org/programmers_guide/meta_graph अनुसार tf.train.export_meta_graph और tf.train.import_meta_graph को कॉल करके सीधे अपने मॉडल को सहेज सकते हैं और पुनर्स्थापित कर सकते हैं।

मॉडल को बचाओ

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta

मॉडल को पुनर्स्थापित करें

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)
Question

टेन्सफोर्लो में एक मॉडल को प्रशिक्षित करने के बाद:

  1. आप प्रशिक्षित मॉडल को कैसे बचाते हैं?
  2. बाद में आप इस सहेजे गए मॉडल को कैसे पुनर्स्थापित करते हैं?



यहां दो मूलभूत मामलों के लिए मेरा सरल समाधान है कि आप फ़ाइल से ग्राफ़ लोड करना चाहते हैं या रनटाइम के दौरान इसे बनाना चाहते हैं।

यह उत्तर Tensorflow 0.12+ (1.0 सहित) के लिए है।

कोड में ग्राफ को पुनर्निर्माण

बचत

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

लोड हो रहा है

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

फ़ाइल से ग्राफ भी लोड हो रहा है

इस तकनीक का उपयोग करते समय, सुनिश्चित करें कि आपके सभी परतों / चरों ने स्पष्ट रूप से अद्वितीय नाम सेट किए हैं। अन्यथा Tensorflow नाम अद्वितीय बना देगा और वे फ़ाइल में संग्रहीत नामों से अलग होंगे। यह पिछली तकनीक में कोई समस्या नहीं है, क्योंकि नाम लोडिंग और बचत दोनों में समान रूप से "उलझन" होते हैं।

बचत

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

लोड हो रहा है

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection



आप यह आसान तरीका भी ले सकते हैं।

चरण 1: अपने सभी चर शुरू करें

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

चरण 2: मॉडल Saver अंदर सत्र को सहेजें और इसे सेव करें

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

चरण 3: मॉडल को पुनर्स्थापित करें

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

चरण 4: अपना चर जांचें

W1 = session.run(W1)
print(W1)

विभिन्न पायथन उदाहरण में चलते समय, उपयोग करें

with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Initalize the variables
    sess.run(tf.global_variables_initializer())

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = session.run(W1)



जैसा कि यारोस्लाव ने कहा था, आप ग्राफ़ को आयात करके ग्राफ़_डेफ़ और चेकपॉइंट से बहाल कर सकते हैं, मैन्युअल रूप से चर बनाते हैं, और उसके बाद सेवर का उपयोग कर सकते हैं।

मैंने इसे अपने व्यक्तिगत उपयोग के लिए कार्यान्वित किया, इसलिए मैं यहां कोड साझा करता हूं।

लिंक: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(यह निश्चित रूप से एक हैक है, और इस बात की कोई गारंटी नहीं है कि इस तरह से सहेजे गए मॉडल टेंसरफ्लो के भविष्य के संस्करणों में पठनीय बने रहेंगे।)




यहां सभी जवाब बहुत अच्छे हैं, लेकिन मैं दो चीजें जोड़ना चाहता हूं।

सबसे पहले, @ user7505159 के उत्तर पर विस्तृत करने के लिए, "./" आपके द्वारा बहाल किए जा रहे फ़ाइल नाम की शुरुआत में जोड़ने के लिए महत्वपूर्ण हो सकता है।

उदाहरण के लिए, आप फ़ाइल नाम में "./" के साथ ग्राफ़ को सहेज सकते हैं जैसे:

# Some graph defined up here with specific names

saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

लेकिन ग्राफ को पुनर्स्थापित करने के लिए, आपको फ़ाइल_नाम में "./" को प्रीपेड करने की आवश्यकता हो सकती है:

# Same graph defined up here

saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)

आपको हमेशा "./" की आवश्यकता नहीं होगी, लेकिन यह आपके पर्यावरण और टेंसरफ्लो के संस्करण के आधार पर समस्याएं पैदा कर सकती है।

यह भी उल्लेख करना चाहता है कि सत्र को बहाल करने से पहले sess.run(tf.global_variables_initializer()) महत्वपूर्ण हो सकता है।

यदि आप सहेजे गए सत्र को पुनर्स्थापित करने का प्रयास करते समय sess.run(tf.global_variables_initializer()) चर के संबंध में कोई त्रुटि प्राप्त कर रहे हैं, तो सुनिश्चित करें कि आप sess.run(tf.global_variables_initializer()) को saver.restore(sess, save_file) पंक्ति से पहले शामिल करते हैं। यह आपको सिरदर्द बचा सकता है।




यदि आप डिफ़ॉल्ट सत्र के रूप में tf.train.MonitoredTrainingSession उपयोग करते हैं, तो आपको चीजों को सहेजने / पुनर्स्थापित करने के लिए अतिरिक्त कोड जोड़ने की आवश्यकता नहीं है। बस मॉनिटरर्ड ट्रेनिंग सत्र के कन्स्ट्रक्टर को चेकपॉइंट डीआईआर नाम पास करें, यह इन्हें संभालने के लिए सत्र हुक का उपयोग करेगा।




आदर्श निर्देशिका में graph.pbtxt . graph.pbtxt रूप में मॉडल की परिभाषा, graph.pbtxt परिभाषा, मॉडल निर्देशिका में graph.pbtxt . graph.pbtxt रूप में सहेजा गया है और tensors के संख्यात्मक मान, model.ckpt-1003418 जैसे चेकपॉइंट फ़ाइलों में सहेजे गए हैं।

मॉडल परिभाषा को tf.import_graph_def का उपयोग करके पुनर्स्थापित किया जा सकता है, और वजन को Saver का उपयोग करके बहाल किया जाता है।

हालांकि, Saver मॉडल ग्राफ से जुड़ी चर की विशेष संग्रह होल्डिंग सूची का उपयोग करता है, और यह संग्रह import_graph_def का उपयोग करके प्रारंभ नहीं किया गया है, इसलिए आप इस समय दोनों को एक साथ उपयोग नहीं कर सकते हैं (यह ठीक करने के लिए हमारे रोडमैप पर है)। अभी के लिए, आपको रयान Sepassi के दृष्टिकोण का उपयोग करना होगा - मैन्युअल रूप से समान नोड नामों के साथ एक ग्राफ का निर्माण, और वजन में लोड करने के लिए Saver का उपयोग करें।

(वैकल्पिक रूप से आप import_graph_def का उपयोग कर, मैन्युअल रूप से चर बनाने, और प्रत्येक चर के लिए tf.add_to_collection(tf.GraphKeys.VARIABLES, variable) का उपयोग करके इसका उपयोग कर हैक कर सकते हैं, फिर Saver का उपयोग कर)