Cucco’s Compute Hack

コンピュータ関係の記事を書いていきます。

テーブルの定義とデータをファイルに。インポートとエクスポート。

テーブルの定義とデータをファイルに。インポートとエクスポート。

import sqlite3
import datetime
import os
import csv

import time # ダミーデータを作るときのsleep用。

class tabledata_file_converter():
    """テーブルの情報をファイルに書き出す、ファイルからテーブルを復元する
    """
    @staticmethod
    def get_table_definition_from_sqlite_master(db_cursor, table_name):
        # dbmsからテーブル定義の情報を取得する
        sql = f"select sql from sqlite_master where type='table' and name='{table_name}'"
        db_cursor.execute(sql)
        row = db_cursor.fetchone()
        if row is None:
            raise
        return row[0]
        
    @staticmethod
    def build_fullpath_table_def_file(table_name, folder_path=None):
        # テーブル定義を書き込むファイルのパスを作る
        post_fix="_table_def.txt"
        if folder_path is None:
            file_fullpath = os.path.dirname(os.path.abspath(__file__)) + os.sep + table_name + post_fix
        else:
            file_fullpath = folder_path + os.sep + table_name + post_fix
        return file_fullpath

    @staticmethod
    def build_fullpath_table_dat_csv(table_name, folder_path=None):
        # テーブルのデータを書き込むファイルのパスを作る
        post_fix="_table_dat.csv"
        if folder_path is None:
            file_fullpath = os.path.dirname(os.path.abspath(__file__)) + os.sep + table_name + post_fix
        else:
            file_fullpath = folder_path + os.sep + table_name + post_fix
        return file_fullpath

    @staticmethod
    def save_def_file(table_name, table_definition, folder_path=None):
        # テーブル定義をファイルに書き込む
        save_file_fullpath = tabledata_file_converter.build_fullpath_table_def_file(table_name, folder_path)

        with open(save_file_fullpath, mode='w', encoding='utf-8') as f:
            f.write(table_definition)

    @staticmethod
    def read_from_def_file(table_name, folder_path=None):
        # テーブル定義をファイルから読み出す。
        read_file_fullpath = tabledata_file_converter.build_fullpath_table_def_file(table_name, folder_path)

        if os.path.isfile(read_file_fullpath) is False:
            raise
        
        with open(read_file_fullpath, mode='r', encoding='utf-8') as f:
            row = f.readline()
        return row
    
    @staticmethod
    def is_exist_table(db_cursor, tale_name):
        #テーブルの存在確認
        db_cursor.execute(f'SELECT COUNT(*) FROM sqlite_master WHERE TYPE="table" AND NAME="{tale_name}"')
        if db_cursor.fetchone() == (0,): #存在しないとき
            return False
        else:
            return True

    @staticmethod
    def list_to_sql_str(list_data):
        '''
        list_to_sql_str(["hoge", 3, 0.14])
        "'hoge',3,0.14" 
        '''
        y=[]
        for item in list_data:
            if type(item) is str:
                y.append("'"+item+"'")
            else:
                y.append(str(item))
        
        return ','.join(y)

    @staticmethod
    def load_from_csv(db_cursor, table_name, folder_path=None):
        # ファイル(データと定義)からテーブルを復元する

        # テーブルの有無確認 テーブルがあればエラー終了させる
        if __class__.is_exist_table(db_cursor, table_name) is True:
            raise

        # テーブルを定義する
        table_def_sql = __class__.read_from_def_file(table_name)
        db_cursor.execute(table_def_sql)

        read_file_fullpath = __class__.build_fullpath_table_dat_csv(table_name, folder_path)

        f = open(read_file_fullpath,mode='r',encoding='utf-8', newline='')
        csv_reader = csv.reader(f, delimiter=',', quotechar='"', lineterminator='\n')

        for row in  csv_reader:
            insert_sql = f"INSERT INTO {table_name} VALUES({ __class__.list_to_sql_str(row)})"
            db_cursor.execute(insert_sql)
        
    @staticmethod
    def save_csv(db_cursor, table_name, folder_path=None):
        save_file_fullpath=tabledata_file_converter.build_fullpath_table_dat_csv(table_name=table_name, folder_path=folder_path)
        
        f = open(save_file_fullpath,mode='w',encoding='utf-8', newline='')
        csv_writer = csv.writer(f, delimiter=',', quotechar='"', lineterminator='\n')
        
        for row in db_cursor.execute(f'SELECT * FROM {table_name}'):
            csv_writer.writerow(row)

        f.close()


if __name__ == '__main__':
    con = sqlite3.connect(":memory:", check_same_thread=False)
    con.isolation_level = None

    cur = con.cursor()
    cur.execute('PRAGMA temp_store=MEMORY;')
    cur.execute('PRAGMA journal_mode=MEMORY;')

    # Create table
    cur.execute('CREATE TABLE stocks (date text, ts timestamp, trans text, symbol text, qty real, price real, add_col integer)')

    # ダミーデータを作る
    for i in range(100):
        # Insert a row of data
        now = datetime.datetime.now()
        cur.execute("INSERT INTO stocks VALUES ('2006-01-05',?,'BUY','RHAT',100,35.14,?)", (now, str(i)))
        time.sleep(0.01)

    # 新しいほうから3件だけを表示する。
    for row in cur.execute("SELECT * From stocks ORDER BY add_col DESC LIMIT 3"):
        print(row)

    # SQL定義の確認
    for row in cur.execute("select sql from sqlite_master where type='table' and name='stocks'"):
        print(row)

    table_definision = tabledata_file_converter.get_table_definition_from_sqlite_master(db_cursor=cur, table_name='stocks')
    print(table_definision)

    # テーブル定義の書き出し
    tabledata_file_converter.save_def_file(table_name='stocks', table_definition=table_definision)

    # テーブル定義の読み出し
    table_definision2 = tabledata_file_converter.read_from_def_file(table_name='stocks')
    print(table_definision2)

    # テーブルデータの書き出し
    tabledata_file_converter.save_csv(cur, "stocks")

    # テーブルの有無
    print(tabledata_file_converter.is_exist_table(cur,"stocks"))
    print(tabledata_file_converter.is_exist_table(cur, "stocksA"))

    # テーブルデータの復元(stocksA向けのファイルを事前に作っておく)
    tabledata_file_converter.load_from_csv(cur, "stocksA")

    # 復元されたデータの表示。新しいほうから3件だけを表示する。
    for row in cur.execute("SELECT * From stocksA ORDER BY add_col DESC LIMIT 3"):
        print(row)

    # 復元されたデータの表示。件数を表示する
    for row in cur.execute("SELECT COUNT(*) From stocksA ORDER BY add_col DESC LIMIT 3"):
        print(row)

    # SQL定義の確認
    for row in cur.execute("select sql from sqlite_master where type='table'"):
        print(row)