#
# Visualization of Numeric arrays using PIL.
#
# Rob W.W. Hooft, Nonius BV, 1999
# Distribute freely.
#
import Numeric

# I have a C version of this as well to speed it up.
def logscale(data,low,high):
    """Map data logarithmically onto the byte range 0..255.

    data -- Numeric array of values, expected to lie between low and high
    low  -- value that is mapped to 0
    high -- value that is mapped to 255

    Returns a Numeric array of typecode Int8 with the shape of data.  When
    high does not exceed low (for instance a constant image) there is no
    range to scale, so an all-zero array is returned instead of dividing
    by log(1) == 0.
    """
    if high <= low:
        return Numeric.zeros(Numeric.shape(data), Numeric.Int8)
    # Shift the data so that `low` becomes 1 and therefore log(low) becomes 0.
    shifted = data - low + 1
    # 255.99 instead of 256 keeps the value for `high` just below 256.
    # The factor is wrapped in a one-element 'f' array, presumably so that
    # single-precision data is not upcast by the multiplication -- TODO confirm.
    fac = Numeric.array([255.99 / Numeric.log(high - low + 1)], 'f')
    # NOTE(review): Int8 is a signed type in Numeric, so results above 127
    # presumably rely on wrap-around to yield the intended byte pattern;
    # UnsignedInt8 may be the better typecode -- confirm before changing.
    return (Numeric.log(shifted) * fac).astype(Numeric.Int8)

class LogScale:
    """Logarithmic scale"""
    def __call__(self,data):
        """Scale data logarithmically between its own minimum and maximum.

        data -- Numeric array (any rank); it is flattened with Numeric.ravel
                only to find the extremes.

        Returns whatever logscale() returns for data and those extremes.
        """
        # Numeric.ravel is the same flattening BoundedScale.limits uses.
        flat = Numeric.ravel(data)
        lowest = Numeric.minimum.reduce(flat)
        highest = Numeric.maximum.reduce(flat)
        return logscale(data, lowest, highest)

# I have a C version of this as well to speed it up.
def linscale(data,low,high):
    """Map data linearly onto the byte range 0..255.

    data -- Numeric array of values, expected to lie between low and high
    low  -- value that is mapped to 0
    high -- value that is mapped to 255

    Returns a Numeric array of typecode Int8 with the shape of data.  When
    high does not exceed low (for instance a constant image) there is no
    range to scale, so an all-zero array is returned instead of dividing
    by zero.
    """
    if high <= low:
        return Numeric.zeros(Numeric.shape(data), Numeric.Int8)
    # 255.99 instead of 256 keeps the value for `high` just below 256.
    # The factor is wrapped in a one-element 'f' array, presumably so that
    # single-precision data is not upcast by the multiplication -- TODO confirm.
    fac = Numeric.array([255.99 / (high - low)], 'f')
    # NOTE(review): Int8 is a signed type in Numeric, so results above 127
    # presumably rely on wrap-around to yield the intended byte pattern;
    # UnsignedInt8 may be the better typecode -- confirm before changing.
    return ((data - low) * fac).astype(Numeric.Int8)

class LinearScale:
    """Linear scale"""
    def __call__(self,data):
        """Scale data linearly between its own minimum and maximum.

        data -- Numeric array (any rank); it is flattened with Numeric.ravel
                only to find the extremes.

        Returns whatever linscale() returns for data and those extremes.
        """
        # Numeric.ravel is the same flattening BoundedScale.limits uses.
        flat = Numeric.ravel(data)
        lowest = Numeric.minimum.reduce(flat)
        highest = Numeric.maximum.reduce(flat)
        return linscale(data, lowest, highest)

# Scaling class that limits the extent of the data.  (I have a GUI subclass of
# this one which enables the user to change all the bounds. That is why it might
# look a bit convoluted.)
class BoundedScale:
    """Wrap a scale object and clip the data to a low/high window first.

    scaleobj is any callable that takes an array and returns the scaled
    array (for example LinearScale() or LogScale()).  The window is chosen
    by autobounds() on the first data set seen after construction or after
    newbounds(), and is reused for later data sets.

    All access to the window goes through getlow/gethigh/setlow/sethigh so
    that subclasses can hook into it.
    """
    def __init__(self,scaleobj):
        self.scaleobj = scaleobj
        self.newbounds()

    def newbounds(self):
        """Make the next call re-derive the window through autobounds()."""
        self.first = 1

    def newscale(self,scaleobj):
        """Replace the wrapped scale object."""
        self.scaleobj = scaleobj

    def __call__(self,data):
        """Clip data to the current window if needed and scale the result.

        The data limits are refreshed on every call; the window itself is
        only computed (and reported on stdout) on the first call after
        newbounds().
        """
        self.limits(data)
        if self.first:
            self.first = 0
            self.autobounds(data)
            print("Image will be scaled between %f and %f"%(self.getlow(),self.gethigh()))
        # Clipping is only needed when the window is narrower than the data.
        if self.getlow() > self.lowlimit or self.gethigh() < self.highlimit:
            return self.scaleobj(Numeric.clip(data, self.getlow(), self.gethigh()))
        return self.scaleobj(data)

    def limits(self,data):
        """Record the minimum and maximum of data as lowlimit/highlimit."""
        flat = Numeric.ravel(data)
        self.lowlimit = Numeric.minimum.reduce(flat)
        self.highlimit = Numeric.maximum.reduce(flat)

    def autobounds(self,data):
        """Choose the initial window; here simply the full data range.

        Expects limits() to have been called already, as __call__ does.
        """
        self.setlow(self.lowlimit)
        self.sethigh(self.highlimit)

    def limitbounds(self):
        """Bring the window into a sane state relative to the data limits.

        Swaps low and high when they are reversed, then clamps both into
        [lowlimit, highlimit].  Clamping each bound against both limits
        guarantees low <= high even when the whole window lies outside the
        data range.  Setters are only called for values that change.
        """
        low = self.getlow()
        high = self.gethigh()
        if low > high:
            low, high = high, low
        low = min(max(low, self.lowlimit), self.highlimit)
        high = max(min(high, self.highlimit), self.lowlimit)
        if low != self.getlow():
            self.setlow(low)
        if high != self.gethigh():
            self.sethigh(high)

    def sethigh(self,val):
        """Store the upper end of the window."""
        self.high1 = val

    def setlow(self,val):
        """Store the lower end of the window."""
        self.low1 = val

    def getlow(self):
        """Return the lower end of the window."""
        return self.low1

    def gethigh(self):
        """Return the upper end of the window."""
        return self.high1

# Statistical helper functions
def sortedmedian(row,fraction=0.5):
    """Pick the element lying `fraction` of the way through a sorted sequence.

    row      -- sequence that is already sorted
    fraction -- position between 0.0 (first element) and 1.0 (last element);
                the default 0.5 selects the median

    The index is fraction*(len(row)-1) truncated to an integer; no
    interpolation between neighbouring elements takes place.
    """
    last = len(row) - 1
    index = int(fraction * last)
    return row[index]

def median(row,fraction=0.5):
    """Return the element(s) found at the given fraction(s) of the sorted data.

    row      -- sequence or Numeric array; it is sorted with Numeric.sort
    fraction -- a single position between 0.0 and 1.0, or a tuple/list of
                such positions

    Returns a single element for a scalar fraction and a tuple of elements,
    in the order of the fractions, when a tuple or list is given.
    """
    ordered = Numeric.sort(row)
    # isinstance also accepts subclasses of tuple and list.
    if isinstance(fraction, (tuple, list)):
        return tuple([sortedmedian(ordered, fr) for fr in fraction])
    return sortedmedian(ordered, fraction)

# Two subclasses that try to get the most contrast out of different kinds of greyscale data.
class MedianBounds(BoundedScale):
    """Bounded scale whose window spans the 0.1 to 0.9 fractions of the data."""
    def autobounds(self,data):
        """Derive both the data limits and the window from the sorted data.

        The values at fractions 0.0 and 1.0 become lowlimit and highlimit;
        the values at fractions 0.1 and 0.9 become the low and high bound.
        """
        quantiles = median(Numeric.ravel(data), fraction=(0.0, 0.1, 0.9, 1.0))
        self.lowlimit = quantiles[0]
        self.highlimit = quantiles[3]
        self.sethigh(quantiles[2])
        self.setlow(quantiles[1])

class AverageBounds(BoundedScale):
    """Bounded scale whose window lies around the average of the data."""
    def autobounds(self,data):
        """Set the window to 0.9 and 1.25 times the mean of all elements.

        The mean is taken over the flattened data with a float divisor, so
        it works for any rank and is not truncated for integer data.  For a
        negative mean 0.9*mean is the larger of the two products, so the
        pair is ordered before being stored to keep low <= high.
        """
        flat = Numeric.ravel(data)
        mean = Numeric.add.reduce(flat) / float(len(flat))
        low = 0.9 * mean
        high = 1.25 * mean
        if low > high:
            low, high = high, low
        self.setlow(low)
        self.sethigh(high)

# Two small helper functions to create a palette for PIL
def colorstr(c):
    """Pack the first three channels of c into a three-character string.

    c -- sequence (r, g, b) of channel values in the range 0..255

    Every channel is passed through int() because chr() needs an integer,
    while callers may hold float values (round() returns a float on
    Python 2).
    """
    return chr(int(c[0])) + chr(int(c[1])) + chr(int(c[2]))

def colorramp(c1,c2,n):
    """Return a list of n colour strings fading linearly from c1 to c2.

    c1 -- (r, g, b) of the first entry
    c2 -- (r, g, b) of the last entry
    n  -- number of entries

    Each entry has the format produced by colorstr().  A ramp of a single
    entry consists of c1 only; that case is handled separately because the
    interpolation divides by n-1.  For n <= 0 the list is empty.
    """
    if n == 1:
        return [colorstr(c1)]
    span = n - 1.0
    # Entry i weights c2 by i and c1 by (n-1-i), so entry 0 is exactly c1
    # and entry n-1 is exactly c2.
    return [colorstr([int(round((i * c2[k] + (n - 1 - i) * c1[k]) / span))
                      for k in (0, 1, 2)])
            for i in range(n)]

# Palette generator function.
def Palette(palette):
    """Build a raw "RGB" PIL palette from a list of 256 colour strings.

    palette -- sequence of exactly 256 strings; they are concatenated and
               handed to ImagePalette.raw

    Returns the ImagePalette object.  If the length is not 256 a message
    is printed and None is returned.
    """
    if len(palette) != 256:
        print("Invalid palette length!!")
        return None
    # Imported here so that PIL is only required once a valid palette is built.
    import ImagePalette
    return ImagePalette.raw("RGB", "".join(palette))

class DisplaySink:
    """Show array data as a palettised PIL image inside a Tkinter label.

    scalefunc is a callable that takes the data array and returns an array
    of byte values; it can be replaced later with newscale().  Palette 0 is
    selected on construction.
    """
    def __init__(self,scalefunc=None):
        self.newscale(scalefunc)
        self.ChoosePalette(0)

    def newscale(self,scalefunc):
        """Install the callable used to turn data into byte values."""
        self.scalefunc = scalefunc

    def ChoosePalette(self,ipalette):
        """Choose from a number of standard palettes"""
        black = (0, 0, 0)
        white = (255, 255, 255)
        red = (255, 0, 0)
        green = (0, 255, 0)
        blue = (0, 0, 255)
        yellow = (255, 255, 0)
        if ipalette == 0:
            # White on black
            entries = colorramp(black, white, 256)
        elif ipalette == 1:
            # Black on white
            entries = colorramp(white, black, 256)
        elif ipalette == 2:
            # White on black, with extremes marked as blue (cold) and red (hot)
            entries = [colorstr(blue)] + colorramp(black, white, 254) + [colorstr(red)]
        elif ipalette == 3:
            # Color ramp.
            entries = (colorramp(black, blue, 48) +
                       colorramp(blue, green, 80) +
                       colorramp(green, red, 80) +
                       colorramp(red, white, 48))
        elif ipalette == 4:
            # Another color ramp
            entries = (colorramp(black, blue, 32) +
                       colorramp(blue, red, 96) +
                       colorramp(red, yellow, 64) +
                       colorramp(yellow, white, 64))
        else:
            # Unknown number: report it and fall back to palette 0.
            print("Invalid palette selected")
            self.ChoosePalette(0)
            return
        self.palette = entries

    def Widget(self,parent):
        """Create, remember and return the Tkinter label that shows the image."""
        import Tkinter
        label = Tkinter.Label(parent, bd=0)
        self.c = label
        return label

    def datatoimage(self,data):
        """Scale data and wrap the resulting bytes in a mode "P" PIL image.

        The image size is passed as (data.shape[1], data.shape[0]) and the
        currently chosen palette is attached to the image.
        """
        import Image
        size = (data.shape[1], data.shape[0])
        pixels = self.scalefunc(data).astype(Numeric.Int8).tostring()
        im = Image.fromstring("P", size, pixels, "raw", "P")
        im.palette = Palette(self.palette)
        return im

    def showimage(self,data):
        """Convert data to an image and display it in the label.

        Widget() must have been called before, since the label it created
        is configured here.  The PIL image and the PhotoImage are stored on
        the instance as self.pili and self.p.
        """
        import ImageTk
        self.pili = self.datatoimage(data)
        self.p = ImageTk.PhotoImage(self.pili)
        self.c.configure(image=self.p)

# Example data for Numeric.
def gen(i,j):
    """Example data: product of two cosines, each with a period of 25.

    i, j -- index values (or index arrays, as supplied by
            Numeric.fromfunction)
    """
    phase_i = 2.0*Numeric.pi*i/25.0
    phase_j = 2.0*Numeric.pi*j/25.0
    return Numeric.cos(phase_i)*Numeric.cos(phase_j)

if __name__=="__main__":
    # Demo: display a 151x101 array of example data in a Tk window.
    # Any 2D array of float or integer data will do as input.
    image=Numeric.fromfunction(gen,(151,101))
    # The scaling function may be any callable that takes an array and returns
    # an array of byte values between 0 and 255; the classes above help to
    # build one.  Here: linear scaling of the data clipped by MedianBounds.
    scaler=MedianBounds(LinearScale())
    # The display object.
    sink=DisplaySink(scaler)
    # Palette 2: red marks the highest value, blue the lowest, with a grey
    # ramp in between.
    sink.ChoosePalette(2)
    # Create the display widget and pack it into the root window.
    label=sink.Widget(None)
    label.pack()
    # Hand the data to the display object.
    sink.showimage(image)
    # Run the Tk event loop until the window is closed.
    label.mainloop()
    
