-
在C#下使用TensorFlow.NET训练自己的数据集(2)
试听地址 https://www.xin3721.com/eschool/CSharpxin3721/
[] { num_filters }); // tf.summary.histogram("bias", b); var layer = tf.nn.conv2d(x, W, strides: new[] { 1, stride, stride, 1 }, padding: "SAME"); layer += b; return tf.nn.relu(layer); }); } /// <summary> /// Create a max pooling layer /// </summary> /// <param name="x">input to max-pooling layer</param> /// <param name="ksize">size of the max-pooling filter</param> /// <param name="stride">stride of the max-pooling filter</param> /// <param name="name">layer name</param> /// <returns>The output array</returns> private Tensor max_pool(Tensor x, int ksize, int stride, string name) { return tf.nn.max_pool(x, ksize: new[] { 1, ksize, ksize, 1 }, strides: new[] { 1, stride, stride, 1 }, padding: "SAME", name: name); } /// <summary> /// Flattens the output of the convolutional layer to be fed into fully-connected layer /// </summary> /// <param name="layer">input array</param> /// <returns>flattened array</returns> private Tensor flatten_layer(Tensor layer) { return tf_with(tf.variable_scope("Flatten_layer"), delegate { var layer_shape = layer.TensorShape; var num_features = layer_shape[new Slice(1, 4)].size; var layer_flat = tf.reshape(layer, new[] { -1, num_features }); return layer_flat; }); } /// <summary> /// Create a weight variable with appropriate initialization /// </summary> /// <param name="name"></param> /// <param name="shape"></param> /// <returns></returns> private RefVariable weight_variable(string name, int[] shape) { var initer = tf.truncated_normal_initializer(stddev: 0.01f); return tf.get_variable(name, dtype: tf.float32, shape: shape, initializer: initer); } /// <summary> /// Create a bias variable with appropriate initialization /// </summary> /// <param name="name"></param> /// <param name="shape"></param> /// <returns></returns> private RefVariable bias_variable(string name, int[] shape) { var initial = tf.constant(0f, shape: shape, dtype: tf.float32); return tf.get_variable(name, dtype: tf.float32, initializer: initial); } /// <summary> /// Create a fully-connected layer /// </summary> /// <param name="x">input from previous layer</param> /// <param name="num_units">number of hidden units in the fully-connected layer</param> /// <param name="name">layer name</param> /// <param name="use_relu">boolean to add ReLU non-linearity (or not)</param> /// <returns>The output array</returns> private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true) { return tf_with(tf.variable_scope(name), delegate { var in_dim = x.shape[1]; var W = weight_variable("W_" + name, shape: new[] { in_dim, num_units }); var b = bias_variable("b_" + name, new[] { num_units }); var layer = tf.matmul(x, W) + b; if (use_relu) layer = tf.nn.relu(layer); return layer; }); } #endregion
模型训练和模型保存
-
Batch数据集的读取,采用了 SharpCV 的cv2.imread,可以直接读取本地图像文件至NDArray,实现CV和Numpy的无缝对接;
-
使用.NET的异步线程安全队列BlockingCollection<T>,实现TensorFlow原生的队列管理器FIFOQueue;
-
在训练模型的时候,我们需要将样本从硬盘读取到内存之后,才能进行训练。我们在会话中运行多个线程,并加入队列管理器进行线程间的文件入队出队操作,并限制队列容量,主线程可以利用队列中的数据进行训练,另一个线程进行本地文件的IO读取,这样可以实现数据的读取和模型的训练是异步的,降低训练时间。
-
-
模型的保存,可以选择每轮训练都保存,或最佳训练模型保存
#region Train public void Train(Session sess) { // Number of training iterations in each epoch var num_tr_iter = (ArrayLabel_Train.Length) / batch_size; var init = tf.global_variables_initializer(); sess.run(init); var saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10); path_model = Name + "\\MODEL"; Directory.CreateDirectory(path_model); float loss_val = 100.0f; float accuracy_val = 0f; var sw = new Stopwatch(); sw.Start(); foreach (var epoch in range(epochs)) { print($"Training epoch: {epoch + 1}"); // Randomly shuffle the training data at the beginning of each epoch (ArrayFileName_Train, ArrayLabel_Train) = ShuffleArray(ArrayLabel_Train.Length, ArrayFileName_Train, ArrayLabel_Train); y_train = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Train)]; //decay learning rate if (learning_rate_step != 0) { if ((epoch != 0) && (epoch % learning_rate_step == 0)) { learning_rate_base = learning_rate_base * learning_rate_decay; if (learning_rate_base <= learning_rate_min) { learning_rate_base = learning_rate_min; } sess.run(tf.assign(learning_rate, learning_rate_base)); } } //Load local images asynchronously,use queue,improve train efficiency BlockingCollection<(NDArray c_x, NDArray c_y, int iter)> BlockC = new BlockingCollection<(NDArray C1, NDArray C2, int iter)>(TrainQueueCapa); Task.Run(() => { foreach (var iteration in range(num_tr_iter)) { var start = iteration * batch_size; var end = (iteration + 1) * batch_size; (NDArray x_batch, NDArray y_batch) = GetNextBatch(sess, ArrayFileName_Train, y_train, start, end); BlockC.Add((x_batch, y_batch, iteration)); } BlockC.CompleteAdding(); }); foreach (var item in BlockC.GetConsumingEnumerable()) { sess.run(optimizer, (x, item.c_x), (y, item.c_y)); if (item.iter % display_freq == 0) { // Calculate and display the batch loss and accuracy var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, item.c_x), new FeedItem(y, item.c_y)); loss_val = result[0]; accuracy_val = result[1]; print("CNN:" + ($"iter {item.iter.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms")); sw.Restart(); } } // Run validation after every epoch (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid)); print("CNN:" + "---------------------------------------------------------"); print("CNN:" + $"gloabl steps: {sess.run(gloabl_steps) },learning rate: {sess.run(learning_rate)}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); print("CNN:" + "---------------------------------------------------------"); if (SaverBest) { if (accuracy_val > max_accuracy) { max_accuracy = accuracy_val; saver.save(sess, path_model + "\\CNN_Best"); print("CKPT Model is save."); } } else { saver.save(sess, path_model + string.Format("\\CNN_Epoch_{0}_Loss_{1}_Acc_{2}", epoch, loss_val, accuracy_val)); print("CKPT Model is save."); } } Write_Dictionary(path_model + "\\dic.txt", Dict_Label); } private void Write_Dictionary(string path, Dictionary<Int64, string> mydic) { FileStream fs = new FileStream(path, FileMode.Create); StreamWriter sw = new StreamWriter(fs); foreach (var d in mydic) { sw.Write(d.Key + "," + d.Value + "\r\n"); } sw.Flush(); sw.Close(); fs.Close(); print("Write_Dictionary"); } private (NDArray, NDArray) Randomize(NDArray x, NDArray y) { var perm = np.random.permutation(y.shape[0]); np.random.shuffle(perm); return (x[perm], y[perm]); } private (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) { var slice = new Slice(start, end); var x_batch = x[slice]; var y_batch = y[slice]; return (x_batch, y_batch); } private unsafe (NDArray, NDArray) GetNextBatch(Session sess, string[] x, NDArray y, int start, int end) { NDArray x_batch = np.zeros(end - start, img_h, img_w, n_channels); int n = 0; for (int i = start; i < end; i++) { NDArray img4 = cv2.imread(x[i], IMREAD_COLOR.IMREAD_GRAYSCALE); x_batch[n] = sess.run(normalized, (decodeJpeg, img4)); n++; } var slice = new Slice(start, end); var y_batch = y[slice]; return (x_batch, y_batch); } #endregion
测试集预测
-
训练完成的模型对test数据集进行预测,并统计准确率
-
计算图中增加了一个提取预测结果Top-1的概率的节点,最后测试集预测的时候可以把详细的预测数据进行输出,方便实际工程中进行调试和优化。
public void Test(Session sess) { (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, x_test), (y, y_test)); print("CNN:" + "---------------------------------------------------------"); print("CNN:" + $"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); print("CNN:" + "---------------------------------------------------------"); (Test_Cls, Test_Data) = sess.run((cls_prediction, prob), (x, x_test)); } private void TestDataOutput() { for (int i = 0; i < ArrayLabel_Test.Length; i++) { Int64 real = ArrayLabel_Test[i]; int predict = (int)(Test_Cls[i]); var probability = Test_Data[i, predict]; string result = (real == predict) ? "OK" : "NG"; string fileName = ArrayFileName_Test[i]; string real_str = Dict_Label[real]; string predict_str = Dict_Label[predict]; print((i + 1).ToString() + "|" + "result:" + result + "|" + "real_str:" + real_str + "|" + "predict_str:" + predict_str + "|" + "probability:" + probability.GetSingle().ToString() + "|" + "fileName:" + fileName); } }
总结
本文主要是.NET下的TensorFlow在实际工业现场视觉检测项目中的应用,使用SciSharp的TensorFlow.NET构建了简单的CNN图像分类模型,该模型包含输入层、卷积与池化层、扁平化层、全连接层和输出层,这些层都是CNN分类模型的必要的层,针对工业现场的实际图像进行了分类,分类准确性较高。
完整代码可以直接用于大家自己的数据集进行训练,已经在工业现场经过大量测试,可以在GPU或CPU环境下运行,只需要更换tensorflow.dll文件即可实现训练环境的切换。
同时,训练完成的模型文件,可以使用 “CKPT+Meta” 或 冻结成“PB” 2种方式,进行现场的部署,模型部署和现场应用推理可以全部在.NET平台下进行,实现工业现场程序的无缝对接。摆脱了以往Python下 需要通过Flask搭建服务器进行数据通讯交互 的方式,现场部署应用时无需配置Python和TensorFlow的环境【无需对工业现场的原有PC升级安装一大堆环境】,整个过程全部使用传统的.NET的DLL引用的方式。