UOMOP

Image PSNR Check_20230508 본문

Wireless Comm./Python

Image PSNR Check_20230508

Happy PinGu 2023. 5. 8. 19:56
def PSNR(ori_img, con_img, max_pixel=255.0):
    """Return the peak signal-to-noise ratio (dB) between two images.

    Args:
        ori_img: reference image (array-like of pixel values).
        con_img: image compared against ``ori_img``; must broadcast with it.
        max_pixel: largest possible pixel value (255 for 8-bit images).

    Returns:
        PSNR in dB rounded to two decimals, or the sentinel 100 when the
        images are identical (MSE of zero).
    """
    # Promote to float before subtracting: the images used here are uint8
    # (bit2img builds uint8 arrays; cv2.imread returns uint8 for 8-bit files),
    # and uint8 subtraction/squaring wraps modulo 256, corrupting the MSE.
    original = np.asarray(ori_img, dtype=np.float64)
    compared = np.asarray(con_img, dtype=np.float64)
    mse = np.mean((original - compared) ** 2)

    if mse == 0:
        return 100

    psnr = 20 * math.log10(max_pixel / math.sqrt(mse))
    return round(psnr, 2)



def img2bit(path):
    """Read an image as 8-bit grayscale and return its pixels as a flat bit list.

    Pixels are taken in row-major order and each one is expanded to 8 bits,
    most significant bit first, so the result has ``8 * rows * cols`` entries,
    each the Python int 0 or 1.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``path`` (it signals
            this by returning ``None`` rather than raising).
    """
    img_gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img_gray is None:
        raise FileNotFoundError("Could not read image: {}".format(path))

    # np.unpackbits expands each uint8 into 8 bits, MSB first -- the same order
    # as format(pixel, 'b').zfill(8).
    pixels = np.asarray(img_gray, dtype=np.uint8).ravel()
    return np.unpackbits(pixels).tolist()


def bit2img(input_sig, shape=None):
    """Pack a flat bit sequence back into a uint8 grayscale image (inverse of img2bit).

    Every 8 consecutive bits, most significant first, form one pixel; trailing
    bits that do not fill a whole byte are ignored.

    Args:
        input_sig: sequence of 0/1 values.
        shape: optional output shape such as (rows, cols). When omitted, the
            row count is floor(sqrt(pixel count)) and the column count is
            inferred, which assumes a square image (the previous behaviour).

    Returns:
        numpy array of dtype uint8.
    """
    bits = np.asarray(input_sig, dtype=np.uint8).ravel()
    usable = (bits.size // 8) * 8  # drop an incomplete final byte
    pixels = np.packbits(bits[:usable])

    if shape is None:
        shape = (math.isqrt(pixels.size), -1)
    return pixels.reshape(shape)


def Compare_SNR(origin_img_name, path, max_snr, mod_mode, chcode_mode, rayleigh):
    """Transmit an image at SNR = 1..max_snr dB and show the received images side by side.

    Args:
        origin_img_name: path of the original image drawn in the first panel.
        path: path of the image converted to bits and transmitted (normally
            the same file as ``origin_img_name``).
        max_snr: number of SNR points; SNR values 1, 2, ..., max_snr are simulated.
        mod_mode, chcode_mode, rayleigh: forwarded unchanged to ``forward``.

    Side effects:
        Saves each received image as "SNR=<snr>.jpg" in the working directory
        and draws a 1 x (max_snr + 1) matplotlib figure.
    """
    gray_stream = img2bit(path)

    titles = ["original"]
    bers = [0]
    received_images = []

    for snr in range(1, max_snr + 1):
        title = "SNR=" + str(snr)
        titles.append(title)

        # Simulate at the SNR the panel is labelled with (the old loop passed
        # SNR=i while labelling it i + 1, so every panel was off by 1 dB).
        _, output_sig, ber = forward(gray_stream, SNR=snr, mod_mode=mod_mode,
                                     rayleigh=rayleigh, chcode_mode=chcode_mode)
        bers.append(ber)

        image_arr = bit2img(input_sig=output_sig)
        received_images.append(image_arr)
        Image.fromarray(image_arr).save(title + ".jpg")
        print("Complete : SNR = " + str(snr))

    plt.figure(figsize=(22, 6))

    plt.subplot(1, max_snr + 1, 1)
    # cmap='gray': without it imshow colour-maps a 2-D array with the default palette.
    plt.imshow(cv2.imread(origin_img_name, cv2.IMREAD_GRAYSCALE), cmap='gray')
    plt.title(titles[0])

    for idx, image_arr in enumerate(received_images, start=1):
        plt.subplot(1, max_snr + 1, idx + 1)
        # Show the exact decoded pixels rather than the lossy JPEG written above.
        plt.imshow(image_arr, cmap='gray')
        plt.title(titles[idx] + "  (BER=" + str(round(bers[idx], 4)) + ")")


def Bit_Gen(how_many):
    """Return ``how_many`` uniformly random bits (ints 0/1) from numpy's global RNG."""
    random_bits = np.random.randint(0, 2, size=how_many)
    return random_bits.tolist()


def BER_Check(list_1, list_2):
    """Return the bit error rate of ``list_2`` measured against reference ``list_1``.

    Bits are compared position by position. If the lengths differ a warning is
    printed; positions of ``list_1`` not covered by ``list_2`` count as errors
    and surplus entries of ``list_2`` are ignored.

    Returns:
        number of errors / len(list_1), rounded to 7 decimals.

    Raises:
        ValueError: if ``list_1`` is empty (the rate is undefined).
    """
    if len(list_1) != len(list_2):
        print("Lengths of the two streams are different.")
    if len(list_1) == 0:
        raise ValueError("Reference bit stream is empty; BER is undefined.")

    print("BER ERROR Checking start")
    overlap = min(len(list_1), len(list_2))
    # Vectorised comparison: the streams hold millions of bits for an image.
    mismatches = np.count_nonzero(
        np.asarray(list_1[:overlap]) != np.asarray(list_2[:overlap]))
    error = int(mismatches) + (len(list_1) - overlap)

    BER = round(error / len(list_1), 7)
    print()

    return BER


class Tx:
    """Transmitter stages: channel encoders and digital modulators.

    All methods are static and are called on the class, e.g.
    ``Tx.Block_Encode(bits)``. Stages that work on whole blocks zero-pad a copy
    of their input and return the number of padding bits next to the result,
    so the matching ``Rx`` stage can strip them again. The caller's list is
    never modified.
    """

    # Generator matrix of the systematic (6, 3) linear block code, G = [P | I3]:
    # positions 3..5 of every codeword repeat the three message bits.
    _BLOCK_GEN = ((1, 1, 0, 1, 0, 0),
                  (0, 1, 1, 0, 1, 0),
                  (1, 0, 1, 0, 0, 1))

    # (d_v, d_c) weights handed to make_ldpc for each supported (n, k) label.
    _LDPC_WEIGHTS = {
        (12, 9): (2, 6),
        (14, 11): (2, 7),
        (10, 7): (2, 5),
        (15, 10): (2, 5),
        (14, 10): (3, 7),
    }

    @staticmethod
    def _padding_count(length, block_size):
        """Return how many zeros round ``length`` up to a multiple of ``block_size``."""
        # 0 when the length already fits; the former `block_size - length % block_size`
        # appended a whole spare block in that case.
        return (-length) % block_size

    @staticmethod
    def _parse_mode(mode):
        """Split a "4-PSK" / "16-QAM" style string into (order M, scheme)."""
        return int(mode[:-3].rstrip('-')), mode[-3:].upper()

    @staticmethod
    def _psk_constellation(M):
        """Return the M-PSK points indexed by natural-binary symbol value.

        Phases are pi/M, 3pi/M, ... for the first M/2 indices, followed by their
        negatives in reverse order, e.g. M=4 -> [pi/4, 3pi/4, -3pi/4, -pi/4].
        """
        first = np.pi * (1 / M)
        step = first * 2
        upper = [first]
        for _ in range(M // 2 - 1):
            upper.append(upper[-1] + step)
        phases = upper + [-phase for phase in reversed(upper)]
        return [complex(np.cos(phase), np.sin(phase)) for phase in phases]

    @staticmethod
    def Block_Encode(input_sig):
        """Encode a bit list with the systematic (6, 3) linear block code.

        Args:
            input_sig: list of 0/1 bits.

        Returns:
            (coded_bits, padding): coded_bits is a list of 0/1 ints, six per
            three-bit message block; padding (0..2) is the number of zeros
            appended to complete the last block.
        """
        k = 3
        padding_num_Block = Tx._padding_count(len(input_sig), k)

        print("(6, 3) Linear Block Coding start")
        message = np.asarray(list(input_sig) + [0] * padding_num_Block, dtype=np.int64)
        codewords = (message.reshape(-1, k) @ np.array(Tx._BLOCK_GEN)) % 2
        print()

        return codewords.ravel().tolist(), padding_num_Block

    @staticmethod
    def LDPC_Encode(input_sig, coding_rate):
        """Encode a bit list with a regular LDPC code generated by make_ldpc.

        Args:
            input_sig: list of 0/1 bits.
            coding_rate: (n, k) label selecting the code. Only n and the
                (d_v, d_c) weights in ``_LDPC_WEIGHTS`` are used; the actual
                message length k is read from the generated G.

        Returns:
            (coded_bits, padding) where padding is the number of zeros appended
            to complete the last message block.

        Raises:
            ValueError: for a coding_rate not listed in ``_LDPC_WEIGHTS``.

        NOTE(review): make_ldpc is called without a seed, so the code depends
        on the global numpy RNG state; Rx.LDPC_Decode regenerates its matrices
        the same way and only matches when it sees the same RNG state.
        """
        if coding_rate not in Tx._LDPC_WEIGHTS:
            raise ValueError("Unsupported LDPC coding_rate: {}".format(coding_rate))
        n = coding_rate[0]
        d_v, d_c = Tx._LDPC_WEIGHTS[coding_rate]

        # Huge SNR: encode() is used only for its (essentially noiseless) codeword.
        snr = 10000000000000
        H, G = make_ldpc(n, d_v, d_c, systematic=True, sparse=True)
        k = G.shape[1]

        print("({}, {})LDPC Channel Coding".format(n, k))

        padding_num_ldpc = Tx._padding_count(len(input_sig), k)
        message = np.array(list(input_sig) + [0] * padding_num_ldpc)

        print("{} LDPC Channel Coding start".format(coding_rate))

        coded_list = []
        for i in trange(len(message) // k):
            symbols = encode(G, message[k * i: k * (i + 1)], snr)
            # encode() yields signed values (negatives were mapped to 0 before);
            # decide on the sign instead of int(), which truncates 0.99 to 0.
            coded_list.extend(1 if value > 0 else 0 for value in np.ravel(symbols))

        print()

        return coded_list, padding_num_ldpc

    @staticmethod
    def Modulation(input_sig, mode):
        """Map a bit list onto complex baseband symbols.

        Args:
            input_sig: list of 0/1 bits.
            mode: "<M>-PSK" (M a power of two >= 2, natural-binary labels) or
                "<M>-QAM" (M an even power of two such as 16 or 64, Gray-coded
                labels per axis, levels -(sqrt(M)-1), ..., sqrt(M)-1).

        Returns:
            (symbols, padding): symbols is a list of complex numbers, one per
            log2(M) bits; padding is the number of zeros appended to fill the
            last symbol.

        Raises:
            ValueError: for an unrecognised or unsupported mode.
        """
        M, scheme = Tx._parse_mode(mode)
        if scheme not in ("PSK", "QAM") or M < 2 or M & (M - 1):
            raise ValueError("Unsupported modulation mode: {}".format(mode))
        k = int(math.log2(M))
        if scheme == "QAM" and k % 2:
            raise ValueError("QAM order must be an even power of two: {}".format(mode))

        padding = Tx._padding_count(len(input_sig), k)
        groups = np.asarray(list(input_sig) + [0] * padding, dtype=np.int64).reshape(-1, k)

        print("{} Modulation start".format(mode))

        if scheme == "PSK":
            constellation = Tx._psk_constellation(M)
            # Natural-binary symbol index, MSB first.
            indices = groups @ (1 << np.arange(k - 1, -1, -1))
            symbols = [constellation[index] for index in indices]
        else:
            half = k // 2
            side = 1 << half  # sqrt(M) amplitude levels per axis
            levels = [1 - side + 2 * i for i in range(side)]
            # Gray-coded label (bit string) of each level index -> amplitude.
            level_of = {format(graycode.tc_to_gray_code(i), 'b').zfill(half): levels[i]
                        for i in range(side)}
            symbols = []
            for group in groups.tolist():
                inphase = level_of["".join(map(str, group[:half]))]
                quadrature = level_of["".join(map(str, group[half:]))]
                symbols.append(complex(inphase, quadrature))

        print()

        return symbols, padding


class Fading_Channel:
    """Flat Rayleigh fading applied sample by sample."""

    def rayleigh(input_sig):
        """Scale each sample by the magnitude of an independent Rayleigh coefficient.

        For every sample a coefficient h = (g1 + j*g2) / sqrt(2) is drawn, with
        g1 and g2 from ``random.gauss(0, 1)``, and the sample is multiplied by
        |h|. Only the magnitude is applied: the channel phase is not modelled
        and nothing here equalises the gain.

        Returns:
            list of faded samples, same length and order as ``input_sig``.
        """
        def draw_gain():
            # Real part drawn first, then imaginary part.
            h = complex(random.gauss(0, 1), random.gauss(0, 1)) / sqrt(2)
            return abs(h)

        return [draw_gain() * sample for sample in input_sig]


class Noise:
    """Additive noise models."""

    def AWGN(input_sig, SNR_dB):
        """Add complex white Gaussian noise at a given per-sample SNR.

        The noise power is N0 = P / 10**(SNR_dB / 10), where P is the mean
        power |x|**2 of ``input_sig``. Each sample receives
        sqrt(N0 / 2) * (g1 + j*g2) with g1, g2 drawn from ``random.gauss(0, 1)``,
        so the complex noise has total variance N0.

        Returns:
            list of noisy samples, same length as ``input_sig`` (empty input
            gives an empty list).
        """
        samples = list(input_sig)
        if not samples:
            return []

        snr_linear = 10 ** (SNR_dB / 10)
        # Signal *power* is the mean of |x|^2. The previous code averaged |x|,
        # which only coincides for unit-magnitude (PSK) symbols and understated
        # the noise for QAM or faded signals.
        signal_power = np.mean([abs(s) ** 2 for s in samples])
        noise_scale = sqrt((signal_power / snr_linear) / 2)

        return [s + noise_scale * complex(random.gauss(0, 1), random.gauss(0, 1))
                for s in samples]


class Rx:
    """Receiver stages: demodulators and channel decoders (inverses of Tx).

    All methods are static and are called on the class, e.g.
    ``Rx.DeModulation(symbols, "4-PSK", padding)``. Each ``padding_*`` argument
    is the count returned by the matching Tx stage; that many trailing bits are
    removed from the result.
    """

    # Transposed parity-check matrix H^T of the (6, 3) code whose generator is
    # G = [P | I3] (Tx.Block_Encode): r @ H^T = 0 (mod 2) exactly for codewords.
    _BLOCK_PARITY_T = ((1, 0, 0),
                       (0, 1, 0),
                       (0, 0, 1),
                       (1, 1, 0),
                       (0, 1, 1),
                       (1, 0, 1))

    # Error pattern to undo, indexed by the syndrome read as s0 + 2*s1 + 4*s2.
    # A single error at bit p produces row p of H^T, so index -> bit is
    # 1->0, 2->1, 3->3, 4->2, 5->5, 6->4. The previous table mapped 1, 2, 5 and
    # 6 to the wrong bits: errors in bits 0/1 corrupted the message and errors
    # in bits 4/5 were never corrected. Syndrome 7 has no single-error cause;
    # it keeps the original double-error guess {1, 5}.
    _BLOCK_ERROR_PATTERNS = ((0, 0, 0, 0, 0, 0),
                             (1, 0, 0, 0, 0, 0),
                             (0, 1, 0, 0, 0, 0),
                             (0, 0, 0, 1, 0, 0),
                             (0, 0, 1, 0, 0, 0),
                             (0, 0, 0, 0, 0, 1),
                             (0, 0, 0, 0, 1, 0),
                             (0, 1, 0, 0, 0, 1))

    # (d_v, d_c) weights handed to make_ldpc for each supported (n, k) label.
    _LDPC_WEIGHTS = {
        (12, 9): (2, 6),
        (14, 11): (2, 7),
        (10, 7): (2, 5),
        (15, 10): (2, 5),
        (14, 10): (3, 7),
    }

    @staticmethod
    def _parse_mode(mode):
        """Split a "4-PSK" / "16-QAM" style string into (order M, scheme)."""
        return int(mode[:-3].rstrip('-')), mode[-3:].upper()

    @staticmethod
    def _psk_constellation(M):
        """Return the M-PSK points indexed by natural-binary symbol value.

        Phases are pi/M, 3pi/M, ... for the first M/2 indices, followed by their
        negatives in reverse order, e.g. M=4 -> [pi/4, 3pi/4, -3pi/4, -pi/4].
        """
        first = np.pi * (1 / M)
        step = first * 2
        upper = [first]
        for _ in range(M // 2 - 1):
            upper.append(upper[-1] + step)
        phases = upper + [-phase for phase in reversed(upper)]
        return [complex(np.cos(phase), np.sin(phase)) for phase in phases]

    @staticmethod
    def _nearest_index(values, points):
        """Return, for each entry of ``values``, the index of the closest entry of ``points``.

        Ties go to the lowest index (strict '<' while scanning ``points``).
        """
        values = np.asarray(values)
        best_dist = np.full(values.shape, np.inf)
        best_idx = np.zeros(values.shape, dtype=np.int64)
        for idx, point in enumerate(points):
            dist = np.abs(values - point)
            closer = dist < best_dist
            best_dist[closer] = dist[closer]
            best_idx[closer] = idx
        return best_idx

    @staticmethod
    def _to_bits(words, width):
        """Expand each integer in ``words`` into ``width`` bits, MSB first, as a flat list."""
        shifts = np.arange(width - 1, -1, -1)
        words = np.asarray(words, dtype=np.int64)
        return ((words[:, None] >> shifts) & 1).ravel().tolist()

    @staticmethod
    def _strip_padding(bits, padding):
        """Drop the last ``padding`` entries of ``bits`` (all of them if fewer)."""
        if padding <= 0:
            return bits
        return bits[:max(len(bits) - padding, 0)]

    @staticmethod
    def DeModulation(input_sig, mode, padding_num_for_demod):
        """Hard-decision demodulation of complex samples back to bits.

        Each sample is mapped to the nearest constellation point (Euclidean
        distance): natural-binary labels for PSK, Gray-coded per-axis labels for
        square QAM, mirroring Tx.Modulation.

        Args:
            input_sig: received complex samples.
            mode: "<M>-PSK" or "<M>-QAM", as given to Tx.Modulation.
            padding_num_for_demod: padding count returned by Tx.Modulation.

        Returns:
            list of 0/1 ints.

        Raises:
            ValueError: for an unrecognised or unsupported mode.
        """
        M, scheme = Rx._parse_mode(mode)
        if scheme not in ("PSK", "QAM") or M < 2 or M & (M - 1):
            raise ValueError("Unsupported modulation mode: {}".format(mode))
        k = int(math.log2(M))
        if scheme == "QAM" and k % 2:
            raise ValueError("QAM order must be an even power of two: {}".format(mode))

        received = np.asarray(input_sig, dtype=complex)

        print("{} DeModulation start".format(mode))

        if scheme == "PSK":
            symbol_words = Rx._nearest_index(received, Rx._psk_constellation(M))
        else:
            half = k // 2
            side = 1 << half  # sqrt(M) amplitude levels per axis
            levels = [1 - side + 2 * i for i in range(side)]
            gray = np.array([graycode.tc_to_gray_code(i) for i in range(side)], dtype=np.int64)
            # On a square grid the nearest 2-D point is the nearest level on
            # each axis separately; no distance cap, so every sample decodes.
            i_idx = Rx._nearest_index(received.real, levels)
            q_idx = Rx._nearest_index(received.imag, levels)
            symbol_words = (gray[i_idx] << half) | gray[q_idx]

        bit_list = Rx._strip_padding(Rx._to_bits(symbol_words, k), padding_num_for_demod)

        print()

        return bit_list

    @staticmethod
    def Block_Decode(input_sig, padding_num_Block):
        """Syndrome-decode the (6, 3) block code and return the message bits.

        Corrects any single bit error per 6-bit word; trailing bits that do not
        fill a whole word are ignored.

        Args:
            input_sig: list of received 0/1 bits.
            padding_num_Block: padding count returned by Tx.Block_Encode.

        Returns:
            list of 0/1 ints (three per word, minus the padding).
        """
        n, k = 6, 3
        print("(6, 3) Linear Block Decoding start")

        num_words = len(input_sig) // n
        words = np.asarray(input_sig[:num_words * n], dtype=np.int64).reshape(-1, n)
        syndromes = (words @ np.array(Rx._BLOCK_PARITY_T)) % 2
        syndrome_index = syndromes @ np.array([1, 2, 4])
        corrected = (words + np.array(Rx._BLOCK_ERROR_PATTERNS)[syndrome_index]) % 2
        # Systematic code: message bits sit at positions 3..5 of each word.
        output_list = corrected[:, k:].ravel().tolist()

        output_list = Rx._strip_padding(output_list, padding_num_Block)

        print()

        return output_list

    @staticmethod
    def LDPC_Decode(input_sig, coding_rate, padding_num_ldpc):
        """Decode LDPC-coded hard bits produced by Tx.LDPC_Encode.

        Args:
            input_sig: list of received 0/1 bits (left unmodified).
            coding_rate: (n, k) label, as passed to Tx.LDPC_Encode.
            padding_num_ldpc: padding count returned by Tx.LDPC_Encode.

        Returns:
            list of decoded message bits.

        Raises:
            ValueError: for a coding_rate not listed in ``_LDPC_WEIGHTS``.

        NOTE(review): make_ldpc is called without a seed, so these matrices only
        match the encoder's when the global numpy RNG is in the same state.
        NOTE(review): decode() gets snr=1e13; if it scales LLRs by a noise
        variance of 10**(-snr/10) (as pyldpc presumably does -- confirm), the
        LLRs overflow to +/-inf and belief propagation cannot correct errors.
        Consider passing the actual channel SNR.
        """
        if coding_rate not in Rx._LDPC_WEIGHTS:
            raise ValueError("Unsupported LDPC coding_rate: {}".format(coding_rate))
        n = coding_rate[0]
        d_v, d_c = Rx._LDPC_WEIGHTS[coding_rate]

        # Map hard bits to the signed values decode() consumes (0 -> -1.0,
        # 1 -> +1.0), building a new array instead of rewriting the caller's list.
        received = np.asarray(input_sig, dtype=float)
        received = np.where(received == 0, -1.0, received)

        snr = 10000000000000
        H, G = make_ldpc(n, d_v, d_c, systematic=True, sparse=True)

        decoded_list = []

        print("{} LDPC Decoding start".format(coding_rate))

        for i in trange(len(received) // n):
            result = decode(H, received[n * i: n * (i + 1)], snr)
            decoded_list.extend(get_message(G, result))

        decoded_list = Rx._strip_padding(decoded_list, padding_num_ldpc)

        print()

        return decoded_list


def forward(input_sig, SNR, mod_mode, rayleigh, chcode_mode):
    """Run one end-to-end link simulation and measure its bit error rate.

    Chain: channel encoder -> modulator -> optional Rayleigh fading -> AWGN
    -> demodulator -> channel decoder.

    Args:
        input_sig: list of 0/1 message bits.
        SNR: SNR in dB handed to Noise.AWGN.
        mod_mode: modulation string for Tx.Modulation / Rx.DeModulation, e.g. "4-PSK".
        rayleigh: "Yes" to pass the symbols through Fading_Channel.rayleigh, "No" otherwise.
        chcode_mode: "LDPC" (fixed (12, 9) setting) or "Block" ((6, 3) block code).

    Returns:
        (input_sig, decoded_sig, BER)

    Raises:
        ValueError: for an unknown ``chcode_mode`` or ``rayleigh`` flag.
    """
    if chcode_mode not in ("LDPC", "Block"):
        raise ValueError("chcode_mode must be 'LDPC' or 'Block', got {!r}".format(chcode_mode))
    if rayleigh not in ("Yes", "No"):
        raise ValueError("rayleigh must be 'Yes' or 'No', got {!r}".format(rayleigh))

    ldpc_rate = (12, 9)

    if chcode_mode == "LDPC":
        # Tx.LDPC_Encode and Rx.LDPC_Decode each call make_ldpc without a seed.
        # These appear to be pyldpc functions, which then draw from the global
        # numpy RNG, so encoder and decoder would build *different* codes.
        # Snapshot the RNG here and replay it before decoding so both sides
        # generate the same H/G pair.
        ldpc_rng_state = np.random.get_state()
        encoded_sig, padding_num_code = Tx.LDPC_Encode(input_sig, ldpc_rate)
    else:
        encoded_sig, padding_num_code = Tx.Block_Encode(input_sig)
    print("Encoded signal's length : {}".format(len(encoded_sig)))

    modulated_sig, padding_num_demod = Tx.Modulation(input_sig=encoded_sig, mode=mod_mode)
    print("Modulated signal({})'s length : {}".format(mod_mode, len(modulated_sig)))

    if rayleigh == "Yes":
        AWGN_input = Fading_Channel.rayleigh(modulated_sig)
    else:
        AWGN_input = modulated_sig

    AWGN_output = Noise.AWGN(AWGN_input, SNR_dB=SNR)
    demodulated_sig = Rx.DeModulation(AWGN_output, mod_mode, padding_num_demod)
    print("Demodulated signal({})'s length : {}".format(mod_mode, len(demodulated_sig)))

    if chcode_mode == "LDPC":
        rng_state_before_decode = np.random.get_state()
        np.random.set_state(ldpc_rng_state)
        decoded_sig = Rx.LDPC_Decode(demodulated_sig, ldpc_rate, padding_num_code)
        # Resume the global stream as if decoding had not consumed from it.
        np.random.set_state(rng_state_before_decode)
    else:
        decoded_sig = Rx.Block_Decode(demodulated_sig, padding_num_code)
    print("Decoded signal's length : {}".format(len(decoded_sig)))

    BER = BER_Check(input_sig, decoded_sig)

    print("BER : {}".format(BER))

    return input_sig, decoded_sig, BER


class BER_Graph:
    """BER-versus-SNR sweeps built on ``forward`` and plotted with matplotlib.

    All methods are static and are called on the class, e.g.
    ``BER_Graph.Compare_rayleigh(bits, 10, 1, "4-PSK", "Block")``.
    """

    @staticmethod
    def _sweep(input_sig, snr_range, snr_step, **link):
        """Run forward() for SNR in np.arange(0, snr_range + 1, snr_step).

        ``link`` carries mod_mode, rayleigh and chcode_mode for forward().
        Returns (snr_values, ber_values).

        Raises:
            ValueError: if the SNR range is empty (nothing to plot).
        """
        snrs, bers = [], []
        for snr in np.arange(0, snr_range + 1, snr_step):
            _, _, ber = forward(input_sig=input_sig, SNR=snr, **link)
            snrs.append(snr)
            bers.append(ber)
            print("SNR = {}, Complete!\tBER : {}".format(snr, ber))
        if not snrs:
            raise ValueError("SNR sweep is empty: check snr_range and snr_step.")
        return snrs, bers

    @staticmethod
    def _plot(snrs, first, second):
        """Draw two (bers, label) BER curves on a log-scale axis and show the figure."""
        (bers_a, label_a), (bers_b, label_b) = first, second
        plt.plot(snrs, bers_a, marker='o', linestyle='dashed', color='blue', label=label_a)
        plt.plot(snrs, bers_b, marker='x', linestyle='dotted', color='red', label=label_b)
        plt.axis([0, snrs[-1] + 1, 1e-6, 1])
        plt.xscale('linear')
        plt.yscale('log')
        # forward() hands this value to Noise.AWGN as the per-sample SNR, not Eb/N0.
        plt.xlabel('SNR (dB)')
        plt.ylabel('BER')

        plt.grid(True)
        plt.legend()
        plt.show()

    @staticmethod
    def Compare_rayleigh(input_sig, snr_range, snr_step, mod_mode, chcode_mode):
        """Plot BER vs SNR without and with Rayleigh fading for one modulation/code."""
        print("Start the code without rayleigh")
        snrs, bers_awgn = BER_Graph._sweep(input_sig, snr_range, snr_step, mod_mode=mod_mode,
                                           rayleigh="No", chcode_mode=chcode_mode)

        print("\n\nStart the code with rayleigh")
        _, bers_fading = BER_Graph._sweep(input_sig, snr_range, snr_step, mod_mode=mod_mode,
                                          rayleigh="Yes", chcode_mode=chcode_mode)

        BER_Graph._plot(snrs, (bers_awgn, 'No rayleigh'), (bers_fading, 'with rayleigh'))

    @staticmethod
    def Compare_chcode(input_sig, snr_range, snr_step, mod_mode, rayleigh):
        """Plot BER vs SNR for LDPC and (6, 3) block coding under the same channel."""
        print("Start the code with LDPC")
        snrs, bers_ldpc = BER_Graph._sweep(input_sig, snr_range, snr_step, mod_mode=mod_mode,
                                           rayleigh=rayleigh, chcode_mode="LDPC")

        print("\n\nStart the code with Linear Block Coding")
        _, bers_block = BER_Graph._sweep(input_sig, snr_range, snr_step, mod_mode=mod_mode,
                                         rayleigh=rayleigh, chcode_mode="Block")

        BER_Graph._plot(snrs, (bers_ldpc, 'LDPC'), (bers_block, 'Block Coding'))
# Transmit lena.png over 4-PSK + (6, 3) block coding at 5 dB and 10 dB and
# compare the received images with the original by PSNR.
LENA_PATH = "/content/lena.png"

ori_img = cv2.imread(LENA_PATH, cv2.IMREAD_GRAYSCALE)
if ori_img is None:
    raise FileNotFoundError("Could not read image: {}".format(LENA_PATH))

# forward() leaves its input list intact, so one conversion serves both runs.
gray_img_bit = img2bit(LENA_PATH)
print(len(gray_img_bit))

received_by_snr = {}
for snr_db in (5, 10):
    input_sig, decoded_sig, BER = forward(gray_img_bit, snr_db, "4-PSK", "No", "Block")
    received_by_snr[snr_db] = bit2img(decoded_sig)
com_img = received_by_snr[5]
com_img_10 = received_by_snr[10]

# One panel per image; separate imshow calls on one axes would overwrite each other.
plt.figure(figsize=(15, 5))
panels = [("original", ori_img), ("SNR = 5 dB", com_img), ("SNR = 10 dB", com_img_10)]
for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.imshow(image, cmap='gray')
    plt.title(title)

snr_5 = PSNR(ori_img, com_img)
print("PSNR(SNR = 5)  : {}".format(snr_5))
snr_10 = PSNR(ori_img, com_img_10)
print("PSNR(SNR = 10) : {}".format(snr_10))

'Wireless Comm. > Python' 카테고리의 다른 글

plt.imshow 파란색으로 되는 현상  (0) 2023.05.24
AutoEncoder for Fashion MNIST  (0) 2023.05.11
Conventional_20230504  (0) 2023.05.04
save2  (0) 2023.05.03
LPDC (n, k) comparison  (0) 2023.05.03
Comments