#{{{ Import modules

import numpy as np
pi = np.pi
array = np.array
sqrt = np.lib.scimath.sqrt
from numpy.linalg import norm

from enthought.traits.api import HasTraits, Int, Float, CFloat, CComplex, CArray, List, Str
from enthought.traits.ui.api import View, Group, HGroup, VGroup, Item, Readonly, spring,\
     Label, Handler
from enthought.traits.ui.menu import Action, MenuBar, Menu, CloseAction
from enthought.enable.component_editor import ComponentEditor
from enthought.chaco.api import marker_trait, Plot, ArrayPlotData, Legend
from enthought.chaco.tools.api import LegendTool, PanTool, ZoomTool, RectZoomTool
from enthought.kiva.fonttools.font import Font
from unit import *
import optics.geometric
import copy
import sdxf

#}}}

#{{{ GaussianBeam Class

class GaussianBeam(HasTraits):
    '''
    A class to represent a Gaussian beam at a particular location.
    The class holds the q parameter, the location and the direction
    of propagation.

    Attributes:
    qx: q-parameter of the beam in the x-direction
    qy: q-parameter of the beam in the y-direction
    pos: Coordinates of the beam
    dirVect: Propagation direction vector (kept normalized)
    dirAngle: Propagation direction angle measured from the positive x-axis
              (kept wrapped into [0, 2*pi))
    Gouyx, Gouyy: Gouy phase accumulated by propagate()
    optDist: Optical distance accumulated by propagate()
    '''

#{{{ Traits Definitions

    name = Str()
    wl = CFloat(1064.0*nm)  #Wavelength
    P = CFloat(1*W)  #Power
    qx = CComplex()  #q-parameter at the origin (x-direction)
    qy = CComplex()  #q-parameter at the origin (y-direction)
    Gouyx = CFloat(0.0) #Accumulated Gouy phase (x-direction)
    Gouyy = CFloat(0.0) #Accumulated Gouy phase (y-direction)
    wx = CFloat()  #Beam size from qx, kept in sync by _qx_changed
    wy = CFloat()  #Beam size from qy, kept in sync by _qy_changed
    n = CFloat(1.0)  #Index used by propagate() (q += d/n, optDist += n*d) and width()

    pos = CArray(dtype=np.float64, shape=(2,))
    length = CFloat(1.0)  #Length of the segment drawn by draw()
    layer = Str()  #DXF layer of the center line (envelope uses layer+"_width")
    dirVect = CArray(dtype=np.float64, shape=(2,))
    dirAngle = CFloat()
    optDist = CFloat(0.0)  #Accumulated optical distance (sum of n*d)

#}}}

#{{{ __init__

    def __init__(self, q0=1j*2*pi/(1064*nm)*1e-6/2, q0x=False, q0y=False,
                 pos=(0.0, 0.0), length=1.0, dirAngle=0.0, dirVect=False,
                 wl=1064*nm, P=1*W, n=1.0, name="Beam", layer='main_beam'):
        '''
        Create a beam.

        q0: q-parameter used for both directions unless q0x/q0y is given.
        q0x, q0y: q-parameters for the x/y directions; used instead of q0
                  when truthy.
        pos: Beam position (any length-2 sequence).
        length: Length of the segment drawn by draw().
        dirAngle: Propagation angle from the +x axis; used only when
                  dirVect is not given.
        dirVect: Propagation direction vector (list, tuple or numpy array);
                 normalized on assignment.
        wl: Wavelength.
        P: Power.
        n: Index used when propagating.
        name: Beam name.
        layer: DXF layer for the center line.
        '''
        # wl must be set before qx/qy: _qx_changed/_qy_changed read self.wl.
        self.wl = wl
        self.P = P
        self.pos = pos
        self.length = length
        self.name = name
        self.layer = layer
        self.n = n

        self.qx = q0x if q0x else q0
        self.qy = q0y if q0y else q0

        # Explicit sentinel test: `if dirVect:` raises for numpy arrays
        # ("truth value of an array ... is ambiguous").
        if dirVect is not False and dirVect is not None:
            self.dirVect = dirVect
        else:
            self.dirAngle = dirAngle
            # Call the handler explicitly so dirVect is synced even when the
            # assignment above does not fire a change notification.
            self._dirAngle_changed(0, 0)

        self.optDist = 0.0


#}}}

#{{{ copy

    def copy(self):
        '''
        Return an independent deep copy of the beam.
        '''
        # `copy` below is the module imported at file level; the method name
        # does not shadow it inside the function body.
        return copy.deepcopy(self)

#}}}

#{{{ propagate

    def propagate(self, d):
        '''
        Propagate the beam by d from the current position.

        Adds d/n to qx and qy, moves pos by d along dirVect, adds n*d to
        optDist and accumulates the Gouy phase change in Gouyx/Gouyy.
        '''
        qx0 = self.qx
        qy0 = self.qy

        self.qx = qx0 + d/self.n
        self.qy = qy0 + d/self.n
        self.pos = self.pos + self.dirVect*d

        #Increase the optical distance
        self.optDist = self.optDist + self.n*d

        #Increase the Gouy phase: arctan(Re(q)/Im(q)) after minus before
        self.Gouyx = self.Gouyx + np.arctan(np.real(self.qx)/np.imag(self.qx))\
                     - np.arctan(np.real(qx0)/np.imag(qx0))

        self.Gouyy = self.Gouyy + np.arctan(np.real(self.qy)/np.imag(self.qy))\
                     - np.arctan(np.real(qy0)/np.imag(qy0))

#}}}

#{{{ ABCD Trans
    def ABCDTrans(self, Mx, My=None):
        '''
        Apply an ABCD transformation to the beam.

        Mx: 2x2 ABCD matrix (indexable as M[i, j]) applied to qx.
        My: 2x2 ABCD matrix applied to qy. If None, Mx is used for qy too.
        '''
        self.qx = (Mx[0,0]*self.qx + Mx[0,1])/(Mx[1,0]*self.qx + Mx[1,1])
        # Identity test: on recent NumPy `My == None` is elementwise for an
        # array, and `not` on that array raises ValueError.
        if My is None:
            My = Mx
        self.qy = (My[0,0]*self.qy + My[0,1])/(My[1,0]*self.qy + My[1,1])


#}}}

#{{{ rotate

    def rotate(self, angle, center=False):
        '''
        Rotate the beam by angle around 'center'.
        If center is not given, only the propagation direction is
        rotated, i.e. the beam is rotated around self.pos.
        center may be a list, tuple or numpy array.
        '''
        # Explicit sentinel test: `if center:` raises for numpy arrays.
        if center is not False and center is not None:
            center = np.array(center)
            pointer = optics.geometric.vector_rotation_2D(self.pos - center, angle)
            self.pos = center + pointer

        self.dirAngle = self.dirAngle + angle

#}}}

#{{{ Translate

    def translate(self, trVect):
        '''
        Shift the beam position by the vector trVect.
        '''
        self.pos = self.pos + np.array(trVect)

#}}}

#{{{ Flip
    def flip(self, flipDirVect=True):
        '''
        Change the propagation direction of the beam
        by 180 degrees.

        The real parts of qx/qy change sign (imaginary parts are kept);
        dirVect is reversed unless flipDirVect is False.
        '''
        self.qx = -np.real(self.qx)+1j*np.imag(self.qx)
        self.qy = -np.real(self.qy)+1j*np.imag(self.qy)
        if flipDirVect:
            self.dirVect = - self.dirVect

#}}}

#{{{ width

    def width(self, dist):
        '''
        Returns the beam width (wx, wy) at a distance dist
        from the pos. dist may be a scalar or an array.
        '''
        dist = np.array(dist)
        k = 2*pi/self.wl
        qx = self.qx + dist/self.n
        qy = self.qy + dist/self.n

        # From 1/q = 1/R - 1j*2/(k*w**2):  w = sqrt(-2/(k*Im(1/q)))
        return (np.sqrt(-2.0/(k*np.imag(1.0/qx))), np.sqrt(-2.0/(k*np.imag(1.0/qy))))

#}}}

#{{{ draw

    def draw(self, dxf, sigma=3., mode='avg', drawWidth=True, fontSize=1,
             drawPower=False, drawROC=False, drawGouy=False, drawOptDist=False, debug=False):
        '''
        Draw itself into a DXF file.

        dxf: Drawing object that the sdxf entities are appended to.
        sigma: Multiplier applied to the beam width for the envelope lines.
        mode: 'x' or 'y' draws that direction's width; any other value draws
              the average of the two.
        drawWidth: Draw the envelope on layer self.layer+"_width".
        fontSize: Text height of the annotations.
        drawPower, drawROC, drawGouy, drawOptDist: Select the text annotations,
              stacked upward from the start point on layer 'text'.
        debug: Unused.
        '''
        start = tuple(self.pos)
        stop = tuple(self.pos + self.dirVect * self.length)

        #Draw the center line
        dxf.append(sdxf.Line(points=[start, stop],
                   layer=self.layer))

        if drawWidth:
            #Draw the width.
            #Sample at a tenth of the shorter Rayleigh range, but no coarser
            #than a tenth of the segment length and with at most 100 points.
            zr = min(q2zr(self.qx), q2zr(self.qy))
            resolution = min(zr/10.0, self.length/10.0)
            #np.linspace needs an integer sample count; current NumPy
            #rejects floats.
            numSegments = int(min(self.length/resolution, 100))

            d = np.linspace(0, self.length, numSegments)
            wx, wy = self.width(d)
            if mode == 'x':
                w = wx*sigma
            elif mode == 'y':
                w = wy*sigma
            else:
                w = sigma*(wx+wy)/2

            #Upper and lower envelope: offset by +/-w in the beam frame,
            #rotate into the beam direction, then move to the beam position.
            for offset in (w, -w):
                v = np.vstack((d, offset))
                v = optics.geometric.vector_rotation_2D(v, self.dirAngle)
                v = v + np.array([self.pos]).T
                dxf.append(sdxf.LwPolyLine(points=list(v.T),
                                           layer=self.layer+"_width"))

        labels = []
        if drawPower:
            labels.append('P='+str(self.P))
        if drawROC:
            labels.append('ROCx='+str(q2R(self.qx)))
            labels.append('ROCy='+str(q2R(self.qy)))
        if drawGouy:
            labels.append('Gouyx='+str(self.Gouyx))
            labels.append('Gouyy='+str(self.Gouyy))
        if drawOptDist:
            labels.append('Optical distance='+str(self.optDist))

        #Stack the labels upward from the start point, 1.2 text heights apart.
        x0, y0 = start
        for i, label in enumerate(labels):
            dxf.append(sdxf.Text(text=label, point=(x0, y0 + i*fontSize*1.2),
                                 height=fontSize, layer='text'))

#}}}

#{{{ Notification Handlers

    def _dirAngle_changed(self, old, new):
        '''
        Traits handler: recompute dirVect from dirAngle and wrap dirAngle
        into [0, 2*pi). Notifications are suppressed so that this handler
        and _dirVect_changed do not trigger each other.
        '''
        self.set(trait_change_notify=False,
                 dirVect=array([np.cos(self.dirAngle), np.sin(self.dirAngle)]))
        self.set(trait_change_notify=False,
                 dirAngle = np.mod(self.dirAngle, 2*pi))

    def _dirVect_changed(self, old, new):
        '''
        Traits handler: normalize dirVect and update dirAngle to match.
        '''
        #Normalize
        self.set(trait_change_notify=False,
                 dirVect = self.dirVect/np.linalg.norm(array(self.dirVect)))
        #Update dirAngle accordingly
        self.set(trait_change_notify=False,
                 dirAngle = np.mod(np.arctan2(self.dirVect[1],
                                              self.dirVect[0]), 2*pi))

    def _qx_changed(self, old, new):
        '''
        Traits handler: keep wx in sync with qx, using the current wl.
        '''
        self.wx = q2w(self.qx, wl=self.wl)

    def _qy_changed(self, old, new):
        '''
        Traits handler: keep wy in sync with qy, using the current wl.
        '''
        self.wy = q2w(self.qy, wl=self.wl)

#}}}

#}}}

#{{{ Utility Functions

def q2zr(q):
    '''
    Convert a q-parameter to Rayleigh range.

    The Rayleigh range is the imaginary part of q; it is returned as a
    builtin float (q must therefore be a scalar).
    '''
    # np.float was only an alias of the builtin float and was removed in
    # NumPy 1.24, so use float directly.
    return float(np.imag(q))

def q2w(q, wl=1064*nm):
    '''
    Convert a q-parameter to the beam size w.

    Uses the convention 1/q = 1/R - 1j*wl/(pi*w**2) (the same one used by
    Rw2q and GaussianBeam.width), so w = sqrt(-wl/(pi*Im(1/q))).

    q: complex q-parameter (scalar or numpy array).
    wl: wavelength.
    '''
    # S = -1/Im(1/q) = pi*w**2/wl
    S = -1.0/np.imag(1.0/q)
    # Solve S = pi*w**2/wl for w.  (sqrt(pi*S/wl) would give pi*w/wl.)
    w = np.sqrt(S*wl/pi)
    return w


def q2R(q):
    '''
    Convert a q-parameter to the ROC (wavefront radius of curvature).

    ROC = 1/Re(1/q).  For a flat wavefront (beam at its waist, Re(1/q) == 0)
    an infinite ROC is returned instead of raising ZeroDivisionError.
    Works elementwise on numpy arrays.
    '''
    curvature = np.real(1.0/q)
    # np.divide yields +/-inf for a zero curvature; silence the warning
    # because a flat wavefront is a legitimate state, not an error.
    with np.errstate(divide='ignore'):
        return np.divide(1.0, curvature)

def Rw2q(ROC=1.0, w=1.0, wl=1064e-9):
    '''
    Build the q-parameter from a radius of curvature ROC and a beam size w.

    1/q = 1/ROC - 1j*2/(k*w**2) with k = 2*pi/wl.  Passing ROC=numpy.inf
    gives a flat wavefront, i.e. a beam at its waist.
    '''
    # NOTE(review): default wl is spelled 1064e-9 here while the other helpers
    # use 1064*nm; presumably the same value if nm == 1e-9 -- confirm.
    wavenumber = 2.0*pi/wl
    spot_term = w**2 * wavenumber/2
    inverse_q = 1.0/ROC + 1.0/(1j*spot_term)
    return 1.0/inverse_q

def zr2w0(zr, wl=1064*nm):
    '''
    Return the waist size w0 corresponding to the Rayleigh range zr.

    w0 = sqrt(2*zr/k) with k = 2*pi/wl, i.e. sqrt(zr*wl/pi).
    '''
    w0_squared = 2*zr*wl/(2*pi)
    return np.sqrt(w0_squared)

def w02zr(w0, wl=1064*nm):
    '''
    Convert the waist size to the Rayleigh range.

    zr = k*w0**2/2 = pi*w0**2/wl with k = 2*pi/wl; this is the inverse
    of zr2w0.
    '''
    return (2*pi/wl)*(w0**2)/2

    
#}}}
