import math
import os
import shutil
from collections import Counter

# Raw data root: holds the label CSV plus the raw train/ and test/ folders.
data_dir = 'DEMO/Data/Dogcls'
label_file = 'labels.csv'
train_dir = 'train'
test_dir = 'test'
valid_dir = 'valid'
# Name of the combined train+valid folder that reorg_dog_data writes.
train_valid_dir = 'train_valid'
# Output folder (created under data_dir) for the reorganised per-class copies.
input_dir = 'train_valid_test'
# Full path of the reorganised data; derived from data_dir/input_dir so the
# two can never drift apart.
input_str = os.path.join(data_dir, input_dir)
# Initial batch size; the script reassigns batch_size later on.
batch_size = 128
# Fraction of the smallest class held out per class for validation.
valid_ratio = 0.1

def reorg_dog_data(data_dir, label_file, train_dir, test_dir, input_dir,
                   valid_ratio):
    """Copy the raw dog-breed images into per-class folders.

    Reads ``<data_dir>/<label_file>`` (one header row, then ``id,label``
    rows) and builds, under ``<data_dir>/<input_dir>``:

    * ``train_valid/<label>/`` -- every training image;
    * ``valid/<label>/`` -- the first ``floor(min_class_size * valid_ratio)``
      images of each class, in ``os.listdir`` order;
    * ``train/<label>/`` -- the remaining training images;
    * ``test/unknown/`` -- every test image, under one placeholder class.

    Args:
        data_dir: root containing the label file and the raw image folders.
        label_file: CSV file name relative to *data_dir*.
        train_dir: raw training-image folder; files are named ``<id>.<ext>``
            where ``<id>`` appears in the label file.
        test_dir: raw test-image folder.
        input_dir: output folder name, created under *data_dir*.
        valid_ratio: fraction of the smallest class used as the per-class
            validation quota.

    Raises:
        KeyError: if a training file's id is missing from the label file.
        ValueError: if a non-blank label row does not have exactly two
            comma-separated fields.
    """
    with open(os.path.join(data_dir, label_file), 'r') as f:
        # Skip the header row (column names).
        rows = [line.rstrip() for line in f.readlines()[1:]]
    # Blank lines (e.g. a trailing newline) carry no label and are skipped.
    idx_label = dict(row.split(',') for row in rows if row)

    # The validation quota comes from the smallest class, so every breed
    # contributes the same number of validation images.
    class_sizes = Counter(idx_label.values())
    min_num_train_per_label = min(class_sizes.values(), default=0)
    num_valid_per_label = math.floor(min_num_train_per_label * valid_ratio)

    def copy_into(src, *dest_parts):
        """Copy *src* into <data_dir>/<input_dir>/<dest_parts...>, creating it."""
        dest = os.path.join(data_dir, input_dir, *dest_parts)
        os.makedirs(dest, exist_ok=True)
        shutil.copy(src, dest)

    # Images sent to valid so far, per class. Counter starts every class at 0,
    # so a quota of 0 sends no image at all to valid.
    label_count = Counter()
    train_root = os.path.join(data_dir, train_dir)
    for train_file in os.listdir(train_root):
        label = idx_label[train_file.split('.')[0]]
        src = os.path.join(train_root, train_file)
        copy_into(src, 'train_valid', label)
        if label_count[label] < num_valid_per_label:
            copy_into(src, 'valid', label)
            label_count[label] += 1
        else:
            copy_into(src, 'train', label)

    # Test images are unlabelled; a single placeholder class folder is created
    # even when the raw test folder is empty.
    os.makedirs(os.path.join(data_dir, input_dir, 'test', 'unknown'),
                exist_ok=True)
    test_root = os.path.join(data_dir, test_dir)
    for test_file in os.listdir(test_root):
        copy_into(os.path.join(test_root, test_file), 'test', 'unknown')


# Copy the raw images into the per-class folder layout read below.
reorg_dog_data(data_dir=data_dir, label_file=label_file, train_dir=train_dir,
               test_dir=test_dir, input_dir=input_dir, valid_ratio=valid_ratio)


from mxnet import gluon
from mxnet import image
import numpy as np
from mxnet import nd

# Per-channel normalisation statistics (the usual ImageNet values), applied
# after pixels have been scaled by 1/255.
_MEAN = np.array([0.485, 0.456, 0.406])
_STD = np.array([0.229, 0.224, 0.225])


def _create_augmenters(size, rand_mirror):
    """Return a CreateAugmenter pipeline for a 3 x size x size input.

    No resize, no random crop/resize, no colour jitter; optional random
    horizontal flip; then mean/std normalisation.
    """
    return image.CreateAugmenter(
        data_shape=(3, size, size), resize=0, rand_crop=False,
        rand_resize=False, rand_mirror=rand_mirror, mean=_MEAN, std=_STD,
        brightness=0, contrast=0, saturation=0, hue=0, pca_noise=0,
        rand_gray=0, inter_method=2)


# Built once at import and reused for every sample. Building them inside the
# transforms recreated four pipelines per image.
# 224 feeds the first backbone input, 299 the second.
_TRAIN_AUGS = {224: _create_augmenters(224, rand_mirror=True),
               299: _create_augmenters(299, rand_mirror=True)}
_TEST_AUGS = {224: _create_augmenters(224, rand_mirror=False),
              299: _create_augmenters(299, rand_mirror=False)}


def _prepare(data, size, auglist):
    """Scale an HWC image by 1/255, resize to size x size, augment, return CHW."""
    im = image.imresize(data.astype('float32') / 255, size, size)
    for aug in auglist:
        im = aug(im)
    # Change layout from height*width*channel to channel*height*width.
    return nd.transpose(im, (2, 0, 1))


def _transform(data, label, augs):
    """Return (224-px image, 299-px image, float32 label scalar)."""
    im1 = _prepare(data, 224, augs[224])
    im2 = _prepare(data, 299, augs[299])
    return (im1, im2, nd.array([label]).asscalar().astype('float32'))


def transform_train(data, label):
    """Training transform: two resized copies with random horizontal flips."""
    return _transform(data, label, _TRAIN_AUGS)


def transform_test(data, label):
    """Evaluation transform: two resized, normalised copies, no randomness."""
    return _transform(data, label, _TEST_AUGS)

# Batch size for the image loaders fed to the feature extractor below.
batch_size = 32

# os.path.join inserts the separator that plain string concatenation dropped
# (input_str + 'train' gave '.../train_valid_testtrain').
train_ds = gluon.data.vision.ImageFolderDataset(
    os.path.join(input_str, train_dir), flag=1, transform=transform_train)
valid_ds = gluon.data.vision.ImageFolderDataset(
    os.path.join(input_str, valid_dir), flag=1, transform=transform_test)
# 'train_valid' is the combined folder name reorg_dog_data writes.
train_valid_ds = gluon.data.vision.ImageFolderDataset(
    os.path.join(input_str, 'train_valid'), flag=1,
    transform=transform_train)
test_ds = gluon.data.vision.ImageFolderDataset(
    os.path.join(input_str, test_dir), flag=1, transform=transform_test)

loader = gluon.data.DataLoader
train_data = loader(train_ds, batch_size, shuffle=True, last_batch='keep')
valid_data = loader(valid_ds, batch_size, shuffle=True, last_batch='keep')
train_valid_data = loader(train_valid_ds, batch_size, shuffle=True,
                          last_batch='keep')

from mxnet import init
from mxnet.gluon import nn

class ConcatNet(nn.HybridBlock):
    """Two-branch extractor: ``net1`` sees ``x1``, ``net2`` sees ``x2``.

    Each branch is followed by a GlobalAvgPool2D. The pooled outputs are
    concatenated with ``F.concat``'s default axis (dim=1).
    """

    def __init__(self, net1, net2, **kwargs):
        super().__init__(**kwargs)
        self.net1 = self._pooled(net1)
        self.net2 = self._pooled(net2)

    @staticmethod
    def _pooled(backbone):
        """Wrap *backbone* in a HybridSequential ending in global average pooling."""
        branch = nn.HybridSequential()
        branch.add(backbone)
        branch.add(nn.GlobalAvgPool2D())
        return branch

    def hybrid_forward(self, F, x1, x2):
        left = self.net1(x1)
        right = self.net2(x2)
        return F.concat(left, right)
    
def get_features2(ctx):
    """Return the ``.features`` sub-network of a pretrained Inception-v3 on *ctx*."""
    return gluon.model_zoo.vision.inception_v3(pretrained=True, ctx=ctx).features

def get_features1(ctx):
    """Return the ``.features`` sub-network of a pretrained ResNet-152 v1 on *ctx*."""
    return gluon.model_zoo.vision.resnet152_v1(pretrained=True, ctx=ctx).features

def get_features(ctx):
    """Build a ConcatNet of ResNet-152 (first input) and Inception-v3 (second input) features."""
    resnet_trunk = get_features1(ctx)
    inception_trunk = get_features2(ctx)
    return ConcatNet(resnet_trunk, inception_trunk)
    
def get_output(ctx, ParamsName=None, prefix='output_'):
    """Build the classification head: Dense(256, relu) -> Dropout(0.7) -> Dense(120).

    Args:
        ctx: context for the head's parameters.
        ParamsName: optional parameter file written by
            ``net.collect_params().save(...)`` on a head built by this
            function. When given, parameters are loaded from it instead of
            being Xavier-initialised.
        prefix: fixed name prefix for the head's parameters.

    Returns:
        The ``nn.HybridSequential`` head.
    """
    # collect_params().save() stores full parameter names, including the block
    # prefix. Without an explicit prefix Gluon auto-numbers it
    # ('hybridsequentialN_'), and N depends on how many blocks were created
    # earlier. The head rebuilt for inference then looked for different names
    # than the ones saved after training, so loading failed. A fixed prefix
    # keeps save and load consistent. Files saved with an auto-numbered prefix
    # will not load here.
    net = nn.HybridSequential(prefix=prefix)
    with net.name_scope():
        net.add(nn.Dense(256, activation="relu"))
        net.add(nn.Dropout(.7))
        net.add(nn.Dense(120))
    if ParamsName is not None:
        net.collect_params().load(ParamsName, ctx)
    else:
        net.initialize(init=init.Xavier(), ctx=ctx)
    return net
    
class OneNet(nn.HybridBlock):
    """Chain a two-input feature extractor with a classification head."""

    def __init__(self, features, output, **kwargs):
        super().__init__(**kwargs)
        self.features = features
        self.output = output

    def hybrid_forward(self, F, x1, x2):
        pooled = self.features(x1, x2)
        return self.output(pooled)
    
def get_net(ParamsName, ctx):
    """Build the full model: pretrained extractor + head (loaded from ParamsName if not None)."""
    # The head is built before the extractor, as in the original call order.
    head = get_output(ctx, ParamsName)
    return OneNet(get_features(ctx), head)
    
from tqdm import tqdm
import datetime
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from mxnet import autograd
import mxnet as mx
import pickle   
    
# Pretrained two-branch extractor, used only for forward passes in SaveNd.
net = get_features(ctx=mx.gpu())
net.hybridize()

def SaveNd(data, net, name, ctx=None):
    """Run *net* over every batch of *data* and save ``[features, labels]``.

    Args:
        data: iterable of ``(input1, input2, label)`` batches.
        net: two-input feature extractor.
        name: file passed to ``nd.save``.
        ctx: context to run *net* on; defaults to ``mx.gpu()``.

    Raises:
        ValueError: if *data* yields no batches.
    """
    if ctx is None:
        ctx = mx.gpu()
    features = []
    labels = []
    print('提取特征 %s' % name)
    for x1, x2, label in tqdm(data):
        out = net(x1.as_in_context(ctx), x2.as_in_context(ctx))
        # Gather results on the CPU before concatenating them.
        features.append(out.as_in_context(mx.cpu()))
        labels.append(label)
    if not features:
        # nd.concat() with no inputs fails with an unhelpful engine error.
        raise ValueError('SaveNd: no batches to extract for %s' % name)
    features = nd.concat(*features, dim=0)
    labels = nd.concat(*labels, dim=0)
    print('保存特征 %s' % name)
    nd.save(name, [features, labels])

# Cache the extracted features for each split.
SaveNd(train_data, net, name='train_r152i3.nd')
SaveNd(valid_data, net, name='valid_r152i3.nd')
SaveNd(train_valid_data, net, name='input_r152i3.nd')
    
# Sorted test-image file names and the class names (folder order of the
# train_valid dataset), persisted for the prediction step.
ids = sorted(os.listdir(os.path.join(data_dir, input_dir, 'test', 'unknown')))
synsets = train_valid_ds.synsets
with open('ids_synsets', 'wb') as f:
    pickle.dump([ids, synsets], f)

# Hyperparameters for training the head on cached features.
num_epochs = 100
batch_size = 128
learning_rate = 1e-4
weight_decay = 1e-4
pngname = 'train.png'
modelparams = 'r152i3.params'

# Each file holds [features, labels] as written by SaveNd.
train_nd = nd.load('train_r152i3.nd')
valid_nd = nd.load('valid_r152i3.nd')
input_nd = nd.load('input_r152i3.nd')
with open('ids_synsets', 'rb') as f:
    ids_synsets = pickle.load(f)

train_data = gluon.data.DataLoader(
    gluon.data.ArrayDataset(train_nd[0], train_nd[1]),
    batch_size=batch_size, shuffle=True)
valid_data = gluon.data.DataLoader(
    gluon.data.ArrayDataset(valid_nd[0], valid_nd[1]),
    batch_size=batch_size, shuffle=True)
input_data = gluon.data.DataLoader(
    gluon.data.ArrayDataset(input_nd[0], input_nd[1]),
    batch_size=batch_size, shuffle=True)

def get_loss(data, net, ctx):
    """Return the average over batches of the batch-mean cross-entropy of *net*.

    Uses the module-level ``softmax_cross_entropy`` loss object.
    """
    running_total = 0.0
    for feas, label in data:
        target = label.as_in_context(ctx)
        logits = net(feas.as_in_context(ctx))
        batch_losses = softmax_cross_entropy(logits, target)
        running_total += nd.mean(batch_losses).asscalar()
    return running_total / len(data)

def train(net, train_data, valid_data, num_epochs, lr, wd, ctx):
    """Train *net* with Adam, print per-epoch losses, then save a plot and parameters.

    Args:
        net: model mapping a feature batch to class logits.
        train_data: iterable of ``(features, label)`` batches.
        valid_data: iterable of ``(features, label)`` batches, or None to
            skip validation.
        num_epochs: number of passes over *train_data*.
        lr: Adam learning rate.
        wd: weight decay passed to the trainer.
        ctx: context on which to run the model.

    Uses the module-level ``softmax_cross_entropy``. Writes the loss curves
    to ``pngname`` and the parameters to ``modelparams``.
    """
    trainer = gluon.Trainer(
        net.collect_params(), 'adam', {'learning_rate': lr, 'wd': wd})
    train_loss = []
    test_loss = []

    prev_time = datetime.datetime.now()
    for epoch in range(num_epochs):
        running = 0.
        for data, label in train_data:
            label = label.as_in_context(ctx)
            with autograd.record():
                output = net(data.as_in_context(ctx))
                loss = softmax_cross_entropy(output, label)
            loss.backward()
            # Normalise by the real number of samples: a final partial batch
            # may be smaller than the global batch_size.
            trainer.step(data.shape[0])
            running += nd.mean(loss).asscalar()
        cur_time = datetime.datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        epoch_train_loss = running / len(train_data)
        train_loss.append(epoch_train_loss)

        if valid_data is not None:
            valid_loss = get_loss(valid_data, net, ctx)
            epoch_str = ("Epoch %d. Train loss: %f, Valid loss %f, "
                         % (epoch, epoch_train_loss, valid_loss))
            test_loss.append(valid_loss)
        else:
            epoch_str = ("Epoch %d. Train loss: %f, "
                         % (epoch, epoch_train_loss))

        prev_time = cur_time
        print(epoch_str + time_str + ', lr ' + str(trainer.learning_rate))

    # Use a fresh figure so repeated calls don't stack curves. Only label
    # curves that were actually drawn.
    fig = plt.figure()
    plt.plot(train_loss, 'r')
    legend = ['Train_Loss']
    if valid_data is not None:
        plt.plot(test_loss, 'g')
        legend.append('Test_Loss')
    plt.legend(legend, loc=2)
    plt.savefig(pngname, dpi=1000)
    plt.close(fig)
    net.collect_params().save(modelparams)

# Loss object read by train() and get_loss().
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
ctx = mx.gpu()
# Newly initialised head (no ParamsName), trained on the cached features.
net = get_output(ctx)
net.hybridize()

train(net, train_data, valid_data, num_epochs=num_epochs, lr=learning_rate,
      wd=weight_decay, ctx=ctx)

netparams = 'r152i3.params'
csvname = 'kaggle.csv'
ids_synsets_name = 'ids_synsets'
with open(ids_synsets_name, 'rb') as f:
    ids_synsets = pickle.load(f)

# Read the reorganised test split, which has one placeholder 'unknown' class
# folder. The raw test folder holds the images directly and has no class
# folder.
test_ds = gluon.data.vision.ImageFolderDataset(
    os.path.join(data_dir, input_dir, 'test'), flag=1,
    transform=transform_test)
# No shuffling, so predictions stay aligned with ids_synsets[0] (the sorted
# test file names). NOTE(review): relies on ImageFolderDataset listing files
# in sorted order; confirm for the installed MXNet version.
test_data = gluon.data.DataLoader(test_ds, batch_size, shuffle=False,
                                  last_batch='keep')
def SaveTest(test_data, net, ctx, name, ids, synsets):
    """Predict class probabilities for the test set and write them as CSV.

    Args:
        test_data: iterable of ``(input1, input2, label)`` batches in the
            same order as *ids*; labels are ignored.
        net: two-input model returning logits (softmax is applied here).
        ctx: context on which to run *net*.
        name: output CSV path.
        ids: test file names; the text before the first '.' is the row id.
        synsets: class names, written as the header after 'id'.

    Raises:
        ValueError: if the number of predictions differs from ``len(ids)``.
            This guards against zip silently dropping rows.
    """
    outputs = []
    for data1, data2, _ in tqdm(test_data):
        logits = net(data1.as_in_context(ctx), data2.as_in_context(ctx))
        outputs.extend(nd.softmax(logits).asnumpy())
    if len(outputs) != len(ids):
        raise ValueError('SaveTest: got %d predictions for %d ids'
                         % (len(outputs), len(ids)))
    with open(name, 'w') as f:
        f.write('id,' + ','.join(synsets) + '\n')
        for i, output in zip(ids, outputs):
            f.write(i.split('.')[0] + ','
                    + ','.join(str(num) for num in output) + '\n')

# Rebuild the full model (pretrained extractor + head loaded from netparams)
# and write the test-set probabilities to csvname.
net = get_net(ParamsName=netparams, ctx=mx.gpu())
net.hybridize()
SaveTest(test_data, net, mx.gpu(), name=csvname,
         ids=ids_synsets[0], synsets=ids_synsets[1])