#!/usr/bin/env python3
# coding: utf8
# Interface graphique pour illustrer la convolution
# Licence libre WTFPL

import tkinter
import numpy as np
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
from matplotlib.pyplot import tight_layout

# Signal definitions ##################################################
Fe = 1e4   # sampling rate (Hz), high enough to mimic continuous time
N = 2*Fe   # number of samples, covering t in [-1 s, +1 s)
n = np.arange(-(N/2), N/2)  # integer-valued sample indices (float dtype)
t = n / Fe                  # time axis (s)
T = 1/3    # width of the rectangular pulse (s)
fBF = 4    # low sine frequency (Hz)
fHF = 48   # high sine frequency (Hz)
t0 = 1/20  # time constant of the exponential (s)
miny = 0.01      # minimum half-range of the vertical axes
res_tau = 0.005  # resolution of the time slider (s)

# Unit-sample "Dirac" pulses: at t = 0, and delayed to t = 0.1 s.
dirac = np.where(t == 0, 1.0, 0.0)
dirdec = np.where(t == 0.1, 1.0, 0.0)

# Unit-height rectangular pulse centred on 0 (|t| < T/2), and its causal
# counterpart on [0, T/2).
por = (np.abs(t) < T/2).astype(float)
porcaus = ((t >= 0) & (t < T/2)).astype(float)

# Causal exponential impulse response, normalised so that sum(expo)/Fe == 1
# (unit gain: convolving it with a unit-height pulse tends towards 1 V).
expo = np.exp(-t / t0) * (t >= 0)
expo = expo/np.sum(expo)*Fe

# Ideal low-pass filter whose cut-off sits halfway between the two sines.
axef = np.arange(N)/N*Fe - Fe/2  # centred frequency axis (Hz)
fc = (fBF + fHF) / 2             # cut-off frequency (Hz)
filtre = np.abs(axef) < fc       # ideal frequency response (boolean)
# Impulse response: inverse FFT of the centred response, re-centred in time,
# truncated by the rectangular pulse and scaled by Fe for a readable plot.
rifiltre = np.real(np.fft.fftshift(np.fft.ifft(np.fft.fftshift(filtre)))) * por * Fe

sinBF = np.sin(2*np.pi*fBF*t)
sinHF = np.sin(2*np.pi*fHF*t)

# Signal table, in menu order (dicts preserve insertion order).
dic_signals = {
    "dirac": dirac,
    "dirac retardé": dirdec,
    "porte": por,
    "porte causale": porcaus,
    "r.i. exponentielle": expo,
    "sinus cardinal tronq": rifiltre,
    "sinus basse fréq": sinBF,
    "sinus haute fréq": sinHF,
}
list_signals = list(dic_signals)


# Classe de base #####################################################
class Convolution:
    """Tk window illustrating the convolution z(t) = (x * y)(t).

    Four stacked plots show x(u); y(u) together with its flipped and shifted
    copy y(t-u); the integrand x(u)·y(t-u), whose shaded area is z(t); and
    z(t) itself with a red marker at the current t.  Signals are chosen in
    two option menus (or with the x/X/y/Y keys) and t with the slider (or
    the arrow keys).

    Relies on the module-level constants ``Fe`` (sampling rate), ``n``
    (integer-valued sample indices matching ``t``), ``miny`` (minimum
    vertical half-range) and ``res_tau`` (slider resolution).

    Args:
        t: time axis (s) shared by every signal of ``dic_signals``; assumed
            uniformly sampled at ``Fe`` (TODO confirm for any new signal).
        dic_signals: mapping signal name -> sampled signal (same length as t).
        list_signals: menu order of the signal names; the first one is
            selected initially for both x and y.
    """

    # Signals drawn as Dirac pulses (with an arrow head).  In this file they
    # are unit samples, i.e. the identity of the discrete convolution, so the
    # result is only scaled by du = 1/Fe when neither operand is one of them.
    SIGNAUX_DIRAC = ("dirac", "dirac retardé")

    # Half-width of the Dirac arrow heads, in samples.
    LARGEUR_FLECHE = 100

    # Slider steps (in multiples of res_tau) bound to the arrow keys.
    PAS_TAU = {"left": -1, "right": 1, "ctrl+left": -10, "ctrl+right": 10}

    def __init__(self, t, dic_signals, list_signals):
        # Data #
        self.t = t
        self.dic_signals = dic_signals
        self.list_signals = list_signals
        # Start on the first listed signal for both x and y, so that the
        # plotted data always matches what the option menus display.
        premier = self.list_signals[0]
        self.sig1 = self.dic_signals[premier]
        self.sig2 = self.dic_signals[premier]
        self.fd1 = self._est_dirac(premier)  # is x(t) a Dirac pulse?
        self.fd2 = self._est_dirac(premier)  # is y(t) a Dirac pulse?
        self.tau = 0
        # (x, y, flags, time axis, z) of the last computed convolution.
        self._cache_conv = None

        # Main window, (almost) full screen #
        self.root = tkinter.Tk()
        self.root.wm_title("Étude de la convolution")
        xmax = self.root.winfo_screenwidth()
        ymax = self.root.winfo_screenheight() - 20
        self.root.geometry(f"{xmax}x{ymax}+0+0")

        # Figure: one row per plot #
        self.fig = Figure()
        self.a1 = self.fig.add_subplot(411)
        self.a2 = self.fig.add_subplot(412)
        self.a3 = self.fig.add_subplot(413)
        self.a4 = self.fig.add_subplot(414)

        menu_width = 150  # width (px) of the control column on the right

        self.canvas = FigureCanvasTkAgg(self.fig, master=self.root)
        self.canvas.get_tk_widget().place(width=xmax - menu_width, height=ymax)
        self.canvas.mpl_connect("key_press_event", self.key_bindings)

        # Control column #
        self.frame = tkinter.Frame(self.root)
        self.frame.place(x=xmax - menu_width, width=menu_width, height=ymax)

        self.scale_tau = tkinter.Scale(master=self.frame, orient='horizontal',
                                       from_=-0.5, to=0.5, resolution=res_tau,
                                       length=menu_width - 20,
                                       command=self.change_tau)
        self.scale_tau.set(self.tau)
        labeltau = tkinter.Label(self.frame, text="valeur de t")
        labeltau.place(rely=0.83, x=30)
        self.scale_tau.place(rely=0.83, x=10, y=20)  # 20: height of the label

        # Signal selection menus #
        self.x, self.choix_x = self._cree_menu("signal x(t)", 0.1, self.change_x)
        self.y, self.choix_y = self._cree_menu("signal y(t)", 0.3, self.change_y)

        # Initial display.  tight_layout + draw are repeated after affiche()
        # so that the very first layout comes out right.
        self.affiche()
        self.fig.tight_layout()
        self.canvas.draw()

    def _cree_menu(self, texte, rely, commande):
        """Create a labelled signal OptionMenu in the control column.

        Returns the (StringVar, OptionMenu) pair; the variable starts on the
        first listed signal and *commande* is called with the chosen name.
        """
        var = tkinter.StringVar()
        menu = tkinter.OptionMenu(self.frame, var, *self.list_signals,
                                  command=commande)
        var.set(self.list_signals[0])
        tkinter.Label(self.frame, text=texte).place(rely=rely, x=10)
        menu.place(rely=rely, x=5, y=20)  # 20: height of the label
        return var, menu

    # Selection callbacks #
    def change_tau(self, tau):
        """Slider callback: set t to *tau* (given as a string) and redraw."""
        self.tau = float(tau)
        self.affiche()

    def change_x(self, signal):
        """Select *signal* (a key of dic_signals) as x(t) and redraw."""
        self.x.set(signal)  # keep the menu in sync for keyboard selection
        self.sig1 = self.dic_signals[signal]
        self.fd1 = self._est_dirac(signal)
        self.affiche()

    def change_y(self, signal):
        """Select *signal* (a key of dic_signals) as y(t) and redraw."""
        self.y.set(signal)  # keep the menu in sync for keyboard selection
        self.sig2 = self.dic_signals[signal]
        self.fd2 = self._est_dirac(signal)
        self.affiche()

    # Drawing #
    def affiche(self):
        """Redraw the four plots and refresh the canvas."""
        self.affiche1()
        self.affiche2()
        self.affiche3()
        self.affiche4()
        self.fig.tight_layout()
        self.canvas.draw()

    def affiche1(self):
        """Plot x(u)."""
        self.a1.clear()
        ligne = self.a1.plot(self.t, self.sig1, linewidth=2)[0]
        if self.fd1:
            k = self.sig1.argmax()
            self._fleche(self.a1, self.t[k], self.sig1[k], ligne.get_color())
        self._habille(self.a1, 'x(u)', 'temps u',
                      max(np.abs(self.sig1).max(), miny))

    def affiche2(self):
        """Plot y(u) and its flipped, shifted copy y(t-u) as functions of u."""
        self.a2.clear()
        ligne = self.a2.plot(self.t, self.sig2, linewidth=2)[0]
        # Sample y(t_j) lies at u = t - t_j on the y(t-u) curve: flipping the
        # time axis (rather than the samples) keeps it exactly on the sample
        # grid used by affiche3 and affiche4, with no one-sample shift.
        ligne_t = self.a2.plot(self.tau - self.t, self.sig2, linewidth=2)[0]
        if self.fd2:
            k = self.sig2.argmax()
            self._fleche(self.a2, self.t[k], self.sig2[k], ligne.get_color())
            self._fleche(self.a2, self.tau - self.t[k], self.sig2[k],
                         ligne_t.get_color())
        self._habille(self.a2, 'y(u) et y(t-u)', 'temps u',
                      max(np.abs(self.sig2).max(), miny))

    def affiche3(self):
        """Plot the integrand x(u)·y(t-u) for |u| < 0.5 s; its area is z(t)."""
        ntau = np.round(self.tau * Fe)  # t expressed in samples
        # x(u) on the window |u| < 0.5 s (exact test on the integer-valued n).
        s1 = self.sig1[np.abs(n) < (Fe / 2)]
        # Samples of y within 0.5 s of t, flipped so that element k is
        # y(t - u_k) on the same u grid as s1 (exact test, no float rounding).
        s2 = self.sig2[np.abs(n - ntau) < (Fe / 2)][::-1]
        axet = self.t[np.abs(self.t) < 0.5]
        produit = s1 * s2
        self.a3.clear()
        ligne = self.a3.plot(axet, produit, linewidth=2)[0]
        if self.fd1 or self.fd2:
            # A unit-sample operand leaves at most one non-zero sample.
            k = np.abs(produit).argmax()
            if produit[k] != 0:
                self._fleche(self.a3, axet[k], produit[k], ligne.get_color())
        self.a3.fill_between(axet, 0, produit, color="green", alpha=0.2)
        self._habille(self.a3, 'x(u) × y(t-u)', 'temps u',
                      max(np.abs(produit).max(), miny))

    def affiche4(self):
        """Plot z(t) = (x * y)(t) with a red marker at the current t."""
        self.a4.clear()
        temps, r = self._convolution()
        ligne = self.a4.plot(temps, r, linewidth=2)[0]
        # Symmetric range from |z| so that negative lobes are not clipped.
        ymax = max(np.abs(r).max(), miny)
        self.a4.plot((self.tau, self.tau), (-1.5 * ymax, 1.5 * ymax),
                     color='r', linewidth=2)
        if self.fd1 and self.fd2:
            k = r.argmax()
            self._fleche(self.a4, temps[k], r[k], ligne.get_color())
        self._habille(self.a4, r'$z(t) = x(t)*y(t) = \int x(u)y(t-u)du$',
                      'temps t', ymax, marge=1.2)

    def _convolution(self):
        """Return (time axis, z) for the current x and y.

        np.convolve costs O(len(x)·len(y)) and does not depend on t, so the
        result is cached and only recomputed when x, y or their flags change.
        """
        cache = self._cache_conv
        drapeaux = (self.fd1, self.fd2)
        if (cache is not None and cache[0] is self.sig1
                and cache[1] is self.sig2 and cache[2] == drapeaux):
            return cache[3], cache[4]
        r = np.convolve(self.sig1, self.sig2, mode="full")
        if not (self.fd1 or self.fd2):
            r = r / Fe  # Riemann sum of the continuous integral, du = 1/Fe
        temps = self._axe_retards(self.t, Fe, r.size)
        self._cache_conv = (self.sig1, self.sig2, drapeaux, temps, r)
        return temps, r

    @staticmethod
    def _axe_retards(t, fe, taille):
        """Return the time axis of np.convolve(a, b, "full") for a, b on *t*.

        Both operands start at t[0], so output sample m lies at
        2*t[0] + m/fe.  *taille* is the length of the convolution output.
        """
        debut = int(round(t[0] * fe))  # index of t[0] in samples
        return (np.arange(taille) + 2 * debut) / fe

    @classmethod
    def _est_dirac(cls, nom):
        """Return True when the signal called *nom* is a Dirac pulse."""
        return nom in cls.SIGNAUX_DIRAC

    def _fleche(self, ax, x, hauteur, couleur):
        """Draw an arrow head on *ax* whose tip is at (x, hauteur)."""
        dx = self.LARGEUR_FLECHE / Fe
        ax.plot((x - dx, x), (0.8 * hauteur, hauteur),
                color=couleur, linewidth=2)
        ax.plot((x, x + dx), (hauteur, 0.8 * hauteur),
                color=couleur, linewidth=2)

    @staticmethod
    def _habille(ax, titre, xlabel, ymax, marge=1.1):
        """Apply the settings shared by the four plots.

        The x range is fixed to [-0.5, 0.5] and the y range to
        ±marge·ymax.  The title and x label are set from the arguments.
        """
        ax.set_xlim((-0.5, 0.5))
        ax.set_ylim((-marge * ymax, marge * ymax))
        ax.set_title(titre)
        ax.grid(True)
        ax.set_ylabel('amplitude')
        ax.set_xlabel(xlabel)

    # Keyboard #
    def _signal_voisin(self, nom, pas):
        """Return the signal *pas* positions after *nom* in list_signals (cyclic)."""
        idx = self.list_signals.index(nom)
        return self.list_signals[(idx + pas) % len(self.list_signals)]

    def key_bindings(self, event):
        """Keyboard shortcuts on the figure.

        x / X: next / previous x(t); y / Y: next / previous y(t);
        q / Q: quit; left / right: t -= / += res_tau;
        ctrl+left / ctrl+right: t -= / += 10*res_tau.
        """
        cle = event.key
        if cle in ("x", "X"):
            pas = 1 if cle == "x" else -1
            self.change_x(self._signal_voisin(self.x.get(), pas))
        elif cle in ("y", "Y"):
            pas = 1 if cle == "y" else -1
            self.change_y(self._signal_voisin(self.y.get(), pas))
        elif cle in ("q", "Q"):
            self.root.quit()
            self.root.destroy()
        elif cle in self.PAS_TAU:
            # Moving the slider runs change_tau through its command option.
            self.scale_tau.set(self.tau + self.PAS_TAU[cle] * res_tau)


# Build the window, then hand control to the Tk event loop of its root.
conv = Convolution(t, dic_signals, list_signals)
conv.root.mainloop()
