How do I read the EMNIST dataset?

AIanswer001 2019-04-02 12:24:07
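When I read the EMNIST dataset in TensorFlow with the same method used for the MNIST dataset, it always reports that there are several extra labels.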
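One possible cause, offered as an assumption because the post does not say which EMNIST split was used: MNIST loading code usually hard-codes 10 classes, while most EMNIST splits have more (for example, the "letters" split has 26 classes labelled 1 to 26). A minimal sketch for inspecting the labels before sizing the output layer; `summarize_labels` is a hypothetical helper and the demo labels are synthetic, not real EMNIST data.

```python
import numpy as np


def summarize_labels(labels):
    """Return (min_label, max_label, num_distinct) for a 1-D label array."""
    distinct = np.unique(np.asarray(labels))
    return int(distinct.min()), int(distinct.max()), int(distinct.size)


# Synthetic stand-in for labels read from a "letters"-style file (assumed 1..26).
demo_labels = np.arange(1, 27, dtype=np.uint8).repeat(3)
lowest, highest, num_classes = summarize_labels(demo_labels)
print(lowest, highest, num_classes)
```

If the number of distinct labels is not 10, or the lowest label is 1 rather than 0, the one-hot width and the size of the final layer need to follow the data rather than the MNIST defaults.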
3 replies
Lijiatu_pro 2021-04-08
In my case there is always more training data than labels...
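A quick way to narrow down a count mismatch like this is to compare the item counts stored in the two file headers. In the IDX format the first 4 bytes are the magic number and the next 4 bytes are the number of items, both big-endian. The sketch below uses only the standard library; `idx_item_count` and `counts_match` are hypothetical helper names, and it assumes the files are gzip-compressed IDX files.

```python
import gzip
import struct


def idx_item_count(path):
    """Return the item count stored in the header of a gzipped IDX file."""
    with gzip.open(path, "rb") as f:
        # Bytes 0-3: magic number, bytes 4-7: number of items (big-endian int32).
        _magic, count = struct.unpack(">ii", f.read(8))
    return count


def counts_match(image_path, label_path):
    """True if the image file and the label file declare the same item count."""
    return idx_item_count(image_path) == idx_item_count(label_path)
```

If the headers disagree, the image and label files probably come from different splits (or a train file is paired with a test file). If they agree, the mismatch is more likely introduced later, for instance by a reshape.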
team-704 2019-07-19

import numpy as np
import matplotlib.pyplot as plt
import gzip

"""
The two entry points are parse_emnist and load_data.

parse_emnist: convert the gzipped EMNIST IDX files ('.gz') into a single .npz file
load_data: load that .npz file (same data structure as the 'mnist' dataset in Keras)
"""


def read_idx3(filename):
    """
    Read the given file with its name
    :param filename: extension name of the file is '.gz'
    :return: images data, shape -> num, rows, cols
    """
    with gzip.open(filename) as fo:
        print('Reading images')
        buf = fo.read()

        offset = 0  # read position (byte offset) in the buffer
        header = np.frombuffer(buf, '>i', 4, offset)
        magic_number, num_images, num_rows, num_cols = header
        print("magic number: {}, \nnumber of images: {},\nnumber of rows: {}, \nnumber of columns: {}" \
              .format(magic_number, num_images, num_rows, num_cols))

        offset += header.size * header.itemsize
        data = np.frombuffer(buf, '>B', num_images * num_rows * num_cols, offset).reshape(
            (num_images, num_rows, num_cols))

        return data


def read_idx1(filename):
    """
    Read the given file with its name
    :param filename: extension name of the file is '.gz'
    :return: labels
    """
    with gzip.open(filename) as fo:
        print('Reading labels')
        buf = fo.read()

        offset = 0
        header = np.frombuffer(buf, '>i', 2, offset)
        magic_number, num_labels = header
        print("magic number: {}, \nnumber of labels: {}" \
              .format(magic_number, num_labels))

        offset += header.size * header.itemsize

        data = np.frombuffer(buf, '>B', num_labels, offset)
        return data


def show(images, labels, letter_mapping, window=(3, 4)):
    # squeeze=False keeps axes 2-D even when the window has a single row or column
    fig, axes = plt.subplots(*window, figsize=(15, 15), squeeze=False)

    for row in range(window[0]):
        for column in range(window[1]):
            ind = window[1] * row + column
            x = images[ind]
            y = labels[ind]

            axes[row][column].imshow(x, cmap=plt.cm.gray)
            axes[row][column].set_title(letter_mapping[y], fontsize=30)
            axes[row][column].axis('off')

    plt.tight_layout()  # lay out once, after every subplot has been drawn
    plt.show()


def parse_emnist(output_path, *arg):
    """
    :param output_path: path of the .npz file to write
    :param arg: four paths -> train_img, train_label, test_img, test_label
    :return: None
    """
    train_img_path, train_label_path, test_img_path, test_label_path = arg
    train_img_data = read_idx3(train_img_path)
    train_label_data = read_idx1(train_label_path)
    test_img_data = read_idx3(test_img_path)
    test_label_data = read_idx1(test_label_path)

    # If the labels start at 1 (as in the 'letters' split), shift them to start at 0
    if np.min(train_label_data) == 1:
        train_label_data = train_label_data - 1
        test_label_data = test_label_data - 1

    np.savez(output_path,
             x_train=train_img_data,
             y_train=train_label_data,
             x_test=test_img_data,
             y_test=test_label_data)
    print('Congratulations, your job has been done! Go to have a rest.')


def load_data(path):
    """Loads the EMNIST dataset.

    # Arguments
        path: path where to load the dataset

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)
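
One thing to check after loading, which the reply above does not cover: EMNIST images read this way are commonly reported to come out transposed (rows and columns swapped) compared with MNIST, so the characters look rotated and mirrored when plotted. A minimal sketch of the fix, assuming arrays shaped (num, rows, cols) like those returned by `read_idx3`; `fix_orientation` is a hypothetical helper and the demo array is synthetic.

```python
import numpy as np


def fix_orientation(images):
    """Swap the row and column axes of a (num, rows, cols) image array."""
    return np.transpose(images, (0, 2, 1))


# Tiny synthetic batch: two 2x2 "images".
demo = np.arange(8, dtype=np.uint8).reshape(2, 2, 2)
fixed = fix_orientation(demo)
print(fixed[0])
```

Plot a few samples first and apply this only if the characters actually look sideways.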





qq_41418522 2019-04-11
+1, I also don't know how to feed this data into a neural network.
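A common way to prepare such arrays for a network is to scale the pixels to [0, 1], add a channel axis, and one-hot encode the labels. This is only a sketch under assumptions: `prepare_for_network` is a hypothetical helper, the labels are assumed to be 0-based already, and `num_classes=26` matches a "letters"-style split, so adjust it to the split actually used. The demo arrays are synthetic.

```python
import numpy as np


def prepare_for_network(images, labels, num_classes):
    """Scale uint8 images to [0, 1], add a channel axis, one-hot the labels."""
    x = images.astype(np.float32) / 255.0
    x = x[..., np.newaxis]  # (num, rows, cols) -> (num, rows, cols, 1)
    # Row i of the identity matrix is the one-hot vector for class i;
    # labels must lie in 0..num_classes-1.
    y = np.eye(num_classes, dtype=np.float32)[labels]
    return x, y


demo_images = np.zeros((4, 28, 28), dtype=np.uint8)
demo_labels = np.array([0, 5, 25, 3], dtype=np.uint8)
x, y = prepare_for_network(demo_images, demo_labels, num_classes=26)
print(x.shape, y.shape)
```

The resulting `x` and `y` have the shapes a typical image classifier with a softmax output of width `num_classes` expects.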
