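Converting the YouTube-8M video model into an audio model
The parameters of the YouTube-8M code are fairly easy to configure. Start with the train.py file in the repository folder. Its imports are: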
import json
import os
import time
import eval_util
import export_model
import losses
import frame_level_models
import video_level_models
import readers
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow import app
from tensorflow import flags
from tensorflow import gfile
from tensorflow import logging
from tensorflow.python.client import device_lib
import utils
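Besides standard-library and TensorFlow imports, these pull in other .py files that already ship with the repository folder (eval_util, export_model, losses, frame_level_models, video_level_models, readers, utils). Those files only resolve when train.py is run from that folder. Only third-party libraries need to be installed with pip if they are missing. Note that tensorflow.contrib exists only in TensorFlow 1.x, so a 1.x release is required.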
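Before launching train.py, you can quickly see which of these imports would fail. The snippet below is a small sketch that uses the standard importlib.util.find_spec. The two lists are written for illustration from the import block above. Run it from the repository folder so that the local files are on the path.

```python
import importlib.util

# .py files that ship inside the YouTube-8M repository folder
# (they are not installed from PyPI).
LOCAL_MODULES = ["eval_util", "export_model", "losses", "frame_level_models",
                 "video_level_models", "readers", "utils"]
# Third-party packages that pip has to provide.
THIRD_PARTY = ["tensorflow"]


def missing_modules(names):
    """Return the names that cannot be found on the current sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("missing repository files:", missing_modules(LOCAL_MODULES))
    print("missing pip packages:", missing_modules(THIRD_PARTY))
```

A name in the first list that is reported missing usually means the script is being run from the wrong directory. A name in the second list calls for a pip install.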
1. Where the model is saved:
flags.DEFINE_string("train_dir", "/tmp/yt8m_model/","The directory to save the model files in.")
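One more note: the flags module is how the Python program receives parameters from outside, that is, from the command line. Any default defined here can therefore be overridden when the script is launched. See tf.app.flags / tf.flags for details.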
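tf.flags follows the usual pattern: define a default, then optionally override it on the command line. The sketch below illustrates that pattern with Python's standard argparse module as a stand-in. It is not the tf.flags API, and it mirrors only a few of train.py's flags, using the defaults shown in this article. With the real script, the equivalent override would be typed as `python train.py --train_dir=...`.

```python
import argparse


def build_parser():
    # argparse stand-in for tf.flags: each add_argument plays the role of one
    # flags.DEFINE_string call in train.py, with the same name and default.
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dir", default="/tmp/yt8m_model/",
                        help="The directory to save the model files in.")
    parser.add_argument("--feature_names", default="audio_embedding")
    parser.add_argument("--model", default="LstmModel")
    return parser


if __name__ == "__main__":
    # The real script reads sys.argv; an explicit list is used here so the
    # example behaves the same however it is run.
    defaults = build_parser().parse_args([])
    overridden = build_parser().parse_args(["--train_dir=E:/Audio_project/model"])
    print(defaults.train_dir)    # /tmp/yt8m_model/
    print(overridden.train_dir)  # E:/Audio_project/model
```

Passing flags at launch time keeps the defaults in train.py untouched, which makes it easier to switch between the video and audio setups.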
2. Where the dataset is stored:
flags.DEFINE_string("train_data_pattern", "E:/Audio_project/audioset/audioset_v1_embeddings/bal_train/*.tfrecord","File glob for the training dataset. If the files refer to Frame Level ""features (i.e. tensorflow.SequenceExample), then set --reader_type ""format. The (Sequence)Examples are expected to have 'rgb' byte array ""sequence feature as well as a 'labels' int64 context feature.")
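As the code shows, this flag points to the dataset. The value is a glob pattern, so *.tfrecord matches every TFRecord file in the bal_train folder. Note that the path separator is "/". Writing "\\" (an escaped backslash inside the Python string) also works.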
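The sketch below checks both separator forms and the glob expansion. It uses pathlib.PureWindowsPath to show that "/" and "\\" name the same Windows path. It also shows why a single unescaped backslash is a trap: "\b" in an ordinary string literal is a backspace character. Finally, it uses the standard glob module to expand a "*.tfrecord" pattern. The repository expands train_data_pattern with TensorFlow's own file utilities, so glob and the helper names here are illustrative stand-ins.

```python
import glob
import os
import tempfile
from pathlib import PureWindowsPath


def same_windows_path(a, b):
    """True if two strings name the same Windows path (separators normalised)."""
    return PureWindowsPath(a) == PureWindowsPath(b)


def list_tfrecords(directory):
    """Expand '<directory>/*.tfrecord', written with '/' as in train_data_pattern."""
    pattern = directory.replace(os.sep, "/") + "/*.tfrecord"
    return sorted(glob.glob(pattern))


if __name__ == "__main__":
    forward = "E:/Audio_project/audioset/audioset_v1_embeddings/bal_train"
    escaped = "E:\\Audio_project\\audioset\\audioset_v1_embeddings\\bal_train"
    print(same_windows_path(forward, escaped))  # True

    # Pitfall: "\b" is ONE character (backspace), not a separator plus "b".
    print(len("\b"), "E:\bal_train" == "E:\\bal_train")  # 1 False

    with tempfile.TemporaryDirectory() as tmp:
        for name in ("a.tfrecord", "b.tfrecord", "notes.txt"):
            open(os.path.join(tmp, name), "w").close()
        print(len(list_tfrecords(tmp)))  # 2
```

Forward slashes avoid escape-sequence surprises entirely. That is one reason to prefer them in the pattern.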
3. Feature name (this part needs special attention):
flags.DEFINE_string("feature_names", "audio_embedding", "Name of the feature ""to use for training.")
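Make sure to set it to "audio_embedding". That is the name under which the AudioSet embedding TFRecords store their frame-level features.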
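train.py also has a --feature_sizes flag that is read together with --feature_names. Both are comma-separated, with one size per name. The steps above do not mention it. However, each AudioSet embedding frame is 128-dimensional, so the size paired with audio_embedding is worth checking. The helper below is hypothetical, written only to illustrate that pairing. It is not the repository's own function.

```python
def parse_feature_flags(feature_names, feature_sizes):
    """Pair comma-separated --feature_names with --feature_sizes.

    Hypothetical helper for illustration: splits both strings, converts the
    sizes to int, and rejects lists of different lengths.
    """
    names = [name.strip() for name in feature_names.split(",")]
    sizes = [int(size) for size in feature_sizes.split(",")]
    if len(names) != len(sizes):
        raise ValueError("feature_names has %d entries but feature_sizes has %d"
                         % (len(names), len(sizes)))
    return names, sizes


if __name__ == "__main__":
    # AudioSet: one 128-dimensional embedding per frame under "audio_embedding".
    print(parse_feature_flags("audio_embedding", "128"))
    # → (['audio_embedding'], [128])
```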
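4. Frame-level features (keep this True, because the AudioSet embeddings are stored per frame rather than aggregated per clip):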
flags.DEFINE_bool("frame_features", True,"If set, then --train_data_pattern must be frame-level features. ""Otherwise, --train_data_pattern must be aggregated video-level ""features. The model must also be set appropriately (i.e. to read 3D ""batches VS 4D batches.")
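Setting frame_features to True means the model receives one embedding per frame rather than a single vector per clip. AudioSet clips are up to 10 seconds long, with one 128-dimensional embedding per second. The numpy sketch below illustrates the shape difference. A frame-level batch is (clips, frames, features), zero-padded, with a per-clip frame count. A video-level feature is what you get after aggregating over frames; mean pooling is used here only as an illustration. The function names are hypothetical, and the repository's reader does its own padding.

```python
import numpy as np

EMBEDDING_SIZE = 128  # dimension of one AudioSet audio_embedding frame


def pad_frames(clips, max_frames):
    """Stack variable-length clips into a zero-padded frame-level batch.

    Returns (batch, num_frames), where batch has shape
    (len(clips), max_frames, EMBEDDING_SIZE).
    """
    batch = np.zeros((len(clips), max_frames, EMBEDDING_SIZE), dtype=np.float32)
    num_frames = np.zeros(len(clips), dtype=np.int64)
    for i, clip in enumerate(clips):
        n = min(len(clip), max_frames)
        batch[i, :n] = clip[:n]
        num_frames[i] = n
    return batch, num_frames


def mean_pool(batch, num_frames):
    """Average only the real (unpadded) frames: (clips, EMBEDDING_SIZE)."""
    counts = np.maximum(num_frames, 1)[:, None]  # avoid dividing by zero
    return batch.sum(axis=1) / counts


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    clips = [rng.normal(size=(10, EMBEDDING_SIZE)), rng.normal(size=(7, EMBEDDING_SIZE))]
    batch, num_frames = pad_frames(clips, max_frames=10)
    print(batch.shape, mean_pool(batch, num_frames).shape)  # (2, 10, 128) (2, 128)
```

A frame-level model such as an LSTM consumes the padded 3D batch together with the frame counts. A video-level model would only see the pooled 2D array.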
5. The most important part: choosing the model
flags.DEFINE_string("model", "LstmModel","Which architecture to use for the model. Models are defined ""in models.py.")
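The main change is to set it to LstmModel. The available models are the ones defined in frame_level_models.py and video_level_models.py. Since we are using frame-level audio features, the model has to come from frame_level_models.py.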
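As far as I can tell from the repository's train.py, the --model string is resolved by looking up a class with that name in frame_level_models first, then in video_level_models. The repository uses a helper called find_class_by_name for this. The sketch below is a simplified, self-contained re-implementation. The modules and classes are placeholder stand-ins, not the real models.

```python
import types


def find_class_by_name(name, modules):
    """Return the first attribute called `name` in `modules` (simplified sketch)."""
    for module in modules:
        cls = getattr(module, name, None)
        if cls is not None:
            return cls
    raise ValueError("model %r not found in any module" % name)


# Placeholder stand-ins for the repository's two model modules.
frame_level_models = types.ModuleType("frame_level_models")
video_level_models = types.ModuleType("video_level_models")


class LstmModel:  # placeholder; the real class runs an LSTM over the frames
    pass


class LogisticModel:  # placeholder for a video-level model
    pass


frame_level_models.LstmModel = LstmModel
video_level_models.LogisticModel = LogisticModel

if __name__ == "__main__":
    model_cls = find_class_by_name("LstmModel", [frame_level_models, video_level_models])
    print(model_cls.__name__)  # LstmModel
```

Because the lookup is purely by name, --model only selects the class. It is still up to you to pair a frame-level model with --frame_features=True.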