首页 热点资讯 义务教育 高等教育 出国留学 考研考公
您的当前位置:首页正文

tensorflow同时载入多个模型

2024-12-18 来源:华拓网

如果在一个项目中同时导入多个模型,会报错,应该是graph冲突,所以需要给每个模型单独新建graph

在这里,tf.variable_scope里面的名称必须和保存的模型中的scope名是一致的

from a_model import model as model1
from b_model import model as model2
import tensorflow as tf

graph1=tf.Graph()
graph2=tf.Graph()

with tf.variable_scope('scope_a'):
    m_a = model1()
    ...

with tf.variable_scope('scope_b'):
    m_b = model2()
    ...

t_vars = tf.global_variables()
a_vars = [var for var in t_vars if var.name.startswith('scope_a')]
b_vars = [var for var in t_vars if var.name.startswith('scope_b')]

model1_path = 'model1/checkpoint.ckpt-000'
with tf.Session() as sess1:
    with graph1.as_default():
        saver1 = tf.train.Saver(a_vars)
        saver1.restore(sess1, model1_path)
    reader = tf.train.NewCheckpointReader(model1_path)
    print(reader.debug_string().decode("utf-8"))

model2_path = 'model2/checkpoint.ckpt-000'
with tf.Session() as sess2:
    with graph2.as_default():
        saver2 = tf.train.Saver(b_vars)
        saver2.restore(sess2, model2_path)
    reader = tf.train.NewCheckpointReader(model2_path)
    print(reader.debug_string().decode("utf-8"))

...

通过tf.train.NewCheckpointReader来打印载入的模型中所保存的参数以及变量名

显示全文