浅谈Tensorflow模型的保存与恢复加载-创新互联

近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测。我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载。

成都创新互联公司专注为客户提供全方位的互联网综合服务,包含不限于成都做网站、成都网站设计、姜堰网络推广、小程序制作、姜堰网络营销、姜堰企业策划、姜堰品牌公关、搜索引擎seo、人物专访、企业宣传片、企业代运营等,从售前售中售后,我们都将竭诚为您服务,您的肯定,是我们大的嘉奖;成都创新互联公司为所有大学生创业者提供姜堰建站搭建服务,24小时服务热线:18982081108,官方网址:www.cdcxhl.com

总结一下Tensorflow常用的模型保存方式。


保存checkpoint模型文件(.ckpt)


首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型。


模型保存


使用tf.train.Saver()来保存模型文件非常方便,下面是一个简单的例子:


import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(ckpt_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  tf.train.Saver().save(sess, ckpt_file_path)
  
  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))


网页名称:浅谈Tensorflow模型的保存与恢复加载-创新互联
网站链接:http://ybzwz.com/article/degjio.html