#!/usr/bin/python2

# Copyright (c) 2007 Johns Hopkins University.
# All rights reserved.
#
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
 # are met:
 #
 # - Redistributions of source code must retain the above copyright
 #   notice, this list of conditions and the following disclaimer.
 # - Redistributions in binary form must reproduce the above copyright
 #   notice, this list of conditions and the following disclaimer in the
 #   documentation and/or other materials provided with the
 #   distribution.
 # - Neither the name of the copyright holders nor the names of
 #   its contributors may be used to endorse or promote products derived
 #   from this software without specific prior written permission.
 #
 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 # FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL
 # THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
 # INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 # STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
 # OF THE POSSIBILITY OF SUCH DAMAGE.

# @author Razvan Musaloiu-E. <razvanm@cs.jhu.edu>
# @author Chieh-Jan Mike Liang <cliang4@cs.jhu.edu>

import sys, stat, struct, subprocess, time, os.path
try:
    import tos
except ImportError:
    import posix
    sys.path = [os.path.join(posix.environ['TOSROOT'], 'support', 'sdk', 'python')] + sys.path
    import tos
from datetime import datetime

# Path to the python script that builds Deluge image from XML.
# It is looked up in the same directory as this script (taken from sys.argv[0]).
PATH_PY_BUILD_IMAGE  = os.path.join(os.path.dirname(sys.argv[0]), 'tos-build-deluge-image')

# TinyOS serial communication parameters
# FM_AMID / DM_AMID are the AM ids passed to am.write() together with
# FMReqPacket / DMReqPacket requests respectively.
FM_AMID = 0x53
DM_AMID = 0x54
# Largest data chunk carried by one FMReqPacket: 28 minus the request header
# fields cmd(1) + imgNum(1) + offset(4) + length(2).
# NOTE(review): 28 is presumably the serial payload size -- confirm.
SERIAL_DATA_LENGTH = 28 - (1+1+4+2)  # See definition of FMReqPacket below

# Commands for FlashManager (value of the 'cmd' field of FMReqPacket).
# FM_CMD_CRC and FM_CMD_ADDR are not referenced by the code visible in this file.
FM_CMD_ERASE     = 0
FM_CMD_WRITE     = 1
FM_CMD_READ      = 2
FM_CMD_CRC       = 3
FM_CMD_ADDR      = 4
FM_CMD_SYNC      = 5
FM_CMD_IDENT     = 6

# Commands for DelugeManager (value of the 'cmd' field of DMReqPacket)
DM_CMD_STOP             = 1
DM_CMD_LOCAL_STOP       = 2
DM_CMD_ONLY_DISSEMINATE = 3
DM_CMD_DISSEMINATE_AND_REPROGRAM = 4
DM_CMD_REPROGRAM        = 5
DM_CMD_BOOT             = 6

# Values compared against the 'error' field of SerialReplyPacket
ERROR_SUCCESS = 0   # T2-compatible
ERROR_FAIL    = 1   # T2-compatible

# Deluge parameters
# DELUGE_IDENT_OFFSET / DELUGE_IDENT_SIZE give the offset and the number of
# bytes that getIdent() reads from an image volume.
# DELUGE_MAX_PAGES is not referenced by the code visible in this file.
DELUGE_MAX_PAGES    = 128
DELUGE_IDENT_OFFSET = 0
DELUGE_IDENT_SIZE   = 128

class FMReqPacket(tos.Packet):
    """Request packet for the FlashManager (written with AM id FM_AMID).

    Fields, in order: cmd, imgNum, offset, length, followed by a
    variable-size data blob.  The sizes of the four integer fields
    (1 + 1 + 4 + 2) are the header sizes subtracted in SERIAL_DATA_LENGTH.
    """

    def __init__(self, packet = None):
        # Each entry is passed to tos.Packet as (field name, field type, size).
        layout = [
            ('cmd',    'int',  1),
            ('imgNum', 'int',  1),
            ('offset', 'int',  4),
            ('length', 'int',  2),
            ('data',   'blob', None),
        ]
        tos.Packet.__init__(self, layout, packet)

class DMReqPacket(tos.Packet):
    """Request packet for the DelugeManager (written with AM id DM_AMID).

    Carries a one-unit command field (one of the DM_CMD_* values) and a
    one-unit image number.
    """

    def __init__(self, packet = None):
        # Each entry is passed to tos.Packet as (field name, field type, size).
        layout = [
            ('cmd',    'int',  1),
            ('imgNum', 'int',  1),
        ]
        tos.Packet.__init__(self, layout, packet)

class SerialReplyPacket(tos.Packet):
    """Reply read back from the mote after a request.

    'error' is compared against ERROR_SUCCESS by the callers; 'data' is a
    variable-size blob holding whatever payload accompanies the reply.
    """

    def __init__(self, packet = None):
        # Each entry is passed to tos.Packet as (field name, field type, size).
        layout = [
            ('error', 'int',  1),
            ('data',  'blob', None),
        ]
        tos.Packet.__init__(self, layout, packet)

class Ident(tos.Packet):
    """Identification record of an image stored in a flash volume.

    getIdent() builds it from the DELUGE_IDENT_SIZE bytes read at
    DELUGE_IDENT_OFFSET.  verifyIdent() checks the 'crc' field against the
    first 10 bytes of the payload, which matches the combined size of the
    fields declared before 'crc' (4 + 4 + 1 + 1).
    """

    def __init__(self, packet = None):
        # Each entry is passed to tos.Packet as (field name, field type, size).
        layout = [
            ('uidhash',  'int', 4),
            ('size',     'int', 4),
            ('pages',    'int', 1),
            ('reserved', 'int', 1),
            ('crc',      'int', 2),
            ('appname', 'string', 16),
            ('username', 'string', 16),
            ('hostname', 'string', 16),
            ('platform', 'string', 16),
            ('timestamp','int', 4),
            ('userhash', 'int', 4),
        ]
        tos.Packet.__init__(self, layout, packet)

class ShortIdent(tos.Packet):
    """Short identification record returned by ident().

    It is decoded from the data of the reply to an FM_CMD_IDENT request and
    printed by ping() under the "Currently Executing:" heading.
    """

    def __init__(self, packet = None):
        # Each entry is passed to tos.Packet as (field name, field type, size).
        layout = [
            ('appname',  'string', 16),
            ('timestamp','int', 4),
            ('uidhash',  'int', 4),
            ('nodeid',   'int', 2),
        ]
        tos.Packet.__init__(self, layout, packet)


# Computes 16-bit CRC
def crc16(data):
    """Return the 16-bit CRC of an iterable of integer byte values.

    The CRC starts at 0, is processed most-significant bit first and uses
    the polynomial 0x1021.  An empty input yields 0.
    """
    remainder = 0
    for byte in data:
        # Fold the next byte into the upper half of the running remainder.
        remainder ^= byte << 8
        for _ in range(8):
            shifted = remainder << 1
            if remainder & 0x8000:
                # The bit shifted out was set: apply the polynomial.
                shifted ^= 0x1021
            remainder = shifted & 0xffff
    return remainder

def handleResponse(success, msg):
    if success == True:
        packet = am.read(timeout=2)
        while packet and packet.type == 100:
            print "".join([chr(i) for i in packet.data])
            packet = am.read()
        if not packet:
            print "No response"
            return False
        reply = SerialReplyPacket(packet.data)
        if reply.error == ERROR_SUCCESS:
            return True
        else:
            print msg, reply
            return False

    print "ERROR: Unable to send the command"
    return False

def ident(timeout=None):
    """Send an FM_CMD_IDENT request and decode the answer.

    timeout -- forwarded unchanged to both am.write() and am.read()

    Returns a ShortIdent built from the reply data on success, or 0 when
    the request could not be written or the reply reports an error.
    """
    request = FMReqPacket((FM_CMD_IDENT, 0, 0, 0, []))
    if not am.write(request, FM_AMID, timeout=timeout):
        return 0

    answer = am.read(timeout=timeout)
    reply = SerialReplyPacket(answer.data)
    if reply.error != ERROR_SUCCESS:
        return 0
    return ShortIdent(reply.data)

def read(imgNum, offset, length):
    """Read `length` bytes starting at `offset` from image volume `imgNum`.

    The range is fetched in consecutive FM_CMD_READ requests of at most
    SERIAL_DATA_LENGTH bytes each.

    Returns the list of collected data values, or None as soon as a request
    cannot be written or a reply reports an error.
    """
    end = offset + length
    collected = []

    request = FMReqPacket((FM_CMD_READ, imgNum, offset, length, []))
    while True:
        # Never ask for more than fits in a single packet.
        if request.length > SERIAL_DATA_LENGTH:
            request.length = SERIAL_DATA_LENGTH

        if not am.write(request, FM_AMID):
            return None

        answer = am.read()
        reply = SerialReplyPacket(answer.data)
        if reply.error != ERROR_SUCCESS:
            return None
        collected.extend(reply.data)

        # Advance past the chunk just received; stop once the range is covered.
        request.offset += request.length
        if request.offset >= end:
            return collected
        request.length = end - request.offset

def erase(imgNum):
    # Note: the normal erase doesn't work properly on AT45DB. A
    # workaround is to do the normal erase (to make happy STM25P)
    # and then overwrite the metadata (to make happy AT45DB).

    sreqpkt = FMReqPacket((FM_CMD_ERASE, imgNum, 0, 0, []))
    success = am.write(sreqpkt, FM_AMID)
    result = handleResponse(success, "ERROR: Unable to erase the flash volume")
    if result: return True;

    print 'Attempt the workaround for AT45DB...'
    sreqpkt = FMReqPacket((FM_CMD_WRITE, imgNum, 0, 0, []))
    sreqpkt.data = [0xFF] * DELUGE_IDENT_SIZE
    sreqpkt.length = DELUGE_IDENT_SIZE
    success = am.write(sreqpkt, FM_AMID)
    result = handleResponse(success, "ERROR: Unable to erase the flash volume")
    if not result: return False;
    return sync(imgNum)

def sync(imgNum):
    """Send FM_CMD_SYNC for image volume `imgNum`; return True on success."""
    request = FMReqPacket((FM_CMD_SYNC, imgNum, 0, 0, []))
    return handleResponse(am.write(request, FM_AMID),
                          "ERROR: Unable to sync the flash volume")

def write(imgNum, data):
    sreqpkt = FMReqPacket((FM_CMD_WRITE, imgNum, 0, 0, []))
    length = len(data)
    total_length = length   # For progress bar
    next_tick = 100         # For progress bar
    start_time = time.time()

    print "[0%        25%         50%         75%         100%]\r[",

    sreqpkt.offset = 0
    while length > 0:
        if ((length * 100) / total_length) < next_tick:
            next_tick = next_tick - 2
            sys.stdout.write('-')
            sys.stdout.flush()

        # Calculates the payload size for the current packet
        if length >= SERIAL_DATA_LENGTH:
            sreqpkt.length = SERIAL_DATA_LENGTH
        else:
            sreqpkt.length = length
        sreqpkt.data = data[sreqpkt.offset:sreqpkt.offset+sreqpkt.length]

        # Sends over serial to the mote
        if not am.write(sreqpkt, FM_AMID):
            print
            print "ERROR: Unable to send the last serial packet (file offset: %d)" % sreqpkt.offset
            return False

        # Waiting for confirmation
        packet = am.read()
        reply = SerialReplyPacket(packet.data)
        if reply.error != ERROR_SUCCESS:
            print
            print "ERROR: Unable to write to the flash volume (file offset: %d)" % sreqpkt.offset
            return False

        length -= sreqpkt.length
        sreqpkt.offset += sreqpkt.length

    print '\r' + ' ' * 52,
    elasped_time = time.time() - start_time
    print "\r%s bytes in %.2f seconds (%.4f bytes/s)" % (total_length, elasped_time, int(total_length) / (elasped_time))
    return True

# Checks for valid CRC and timestamp
def verifyIdent(i):
    if i != None:
        if crc16(i.payload()[0:10]) == i.crc and i.timestamp != 0xFFFFFFFF:
            return True
        else:
            print "No valid image was detected."
    return False

def getIdent(imgNum):
    """Read the ident of image volume `imgNum`.

    Returns an Ident built from the DELUGE_IDENT_SIZE bytes found at
    DELUGE_IDENT_OFFSET, or None (after printing an error) when the read
    returned nothing.
    """
    raw = read(imgNum, DELUGE_IDENT_OFFSET, DELUGE_IDENT_SIZE)
    if not raw:
        print "ERROR: Unable to retrieve the ident."
        return None
    return Ident(raw)

def formatIdent(i):
    """Return a multi-line, human-readable description of ident `i`.

    `i` must provide the attributes appname, uidhash, timestamp, platform,
    username, hostname, userhash, size and pages.  The timestamp is
    rendered in local time.  The result has no trailing newline.
    """
    rows = [
        "  Prog Name:   %s" % (i.appname),
        "  UID:         0x%08X" % (i.uidhash),
        "  Compiled On: %s" % (datetime.fromtimestamp(i.timestamp).strftime('%a %h %d %T %Y')),
        "  Platform:    %s" % (i.platform),
        "  User ID:     %s" % (i.username),
        "  Host Name:   %s" % (i.hostname),
        "  User Hash:   0x%08X" % (i.userhash),
        "  Size:        %d" % (i.size),
        "  Num Pages:   %d" % (i.pages),
    ]
    return "\n".join(rows)

def formatShortIdent(i):
    """Return a multi-line, human-readable description of short ident `i`.

    `i` must provide the attributes appname, uidhash, timestamp and nodeid.
    The timestamp is rendered in local time.  Every line, including the
    last one, ends with a newline.
    """
    rows = [
        "  Prog Name:   %s" % (i.appname),
        "  UID:         0x%08X" % (i.uidhash),
        "  Compiled On: %s" % (datetime.fromtimestamp(i.timestamp).strftime('%a %h %d %T %Y')),
        "  Node ID:     %d" % (i.nodeid),
    ]
    return "\n".join(rows) + "\n"

# Injects an image (specified by tos_image_xml) to an image volume
def inject(imgNum, tos_image_xml):
    """Build a Deluge image from `tos_image_xml` and store it in volume `imgNum`.

    imgNum        -- number of the image volume to overwrite
    tos_image_xml -- path of the TOS image XML handed to the build utility

    Steps: check that both the XML file and the build utility exist, read
    the ident currently stored in the volume (this doubles as a
    connectivity check), run the build utility, then erase, write and sync
    the volume and print the ident of the new image.

    Returns True when the image was written and synced, False otherwise.
    """
    # Checks for valid file path
    try:
        os.stat(tos_image_xml)         # Checks whether tos_image_xml is a valid file
    except OSError:
        print "ERROR: Unable to find the TOS image XML, \"%s\"" % tos_image_xml
        return False
    try:
        os.stat(PATH_PY_BUILD_IMAGE)   # Checks whether PATH_PY_BUILD_IMAGE is a valid file
    except OSError:
        print "ERROR: Unable to find the image building utility, \"%s\"" % PATH_PY_BUILD_IMAGE
        return False

    # Gets status information of stored image.  The check is on the ident
    # that was just read (getIdent() returns None and prints an error when
    # the read failed), not on the module-level function ident(), which is
    # always true.
    i = getIdent(imgNum)
    if i is None:
        return False
    print "Connected to Deluge nodes."
    if verifyIdent(i):
        print "--------------------------------------------------"
        print "Stored image %d" % imgNum
        print formatIdent(i)

    # Creates binary image from the TOS image XML
    print "--------------------------------------------------"
    cmd = [PATH_PY_BUILD_IMAGE, "-i", str(imgNum), tos_image_xml]
    print "Create image:", ' '.join(cmd)
    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    (out, err) = p.communicate(None)
    print err,
    print "--------------------------------------------------"

    # Writes the new binary image; the build utility's stdout is the image,
    # converted here to a list of byte values.
    image = [struct.unpack("B", c)[0] for c in out]
    if len(image) == 0:
        return False
    if not erase(imgNum):
        return False
    if not write(imgNum, image):
        return False
    if not sync(imgNum):
        return False

    print "--------------------------------------------------"
    print "Replace image with:"
    new_ident = getIdent(imgNum)
    # getIdent() already reports a failed read; avoid formatting None.
    if new_ident is not None:
        print formatIdent(new_ident)
    print "--------------------------------------------------"
    return True

def ping(imgNum):
    uid = ident()
    # Prints out image status
    print "--------------------------------------------------"
    print "Currently Executing:"
    print formatShortIdent(ident())
    i = getIdent(imgNum)
    if verifyIdent(i):
        print "Stored image %d" % imgNum
        print formatIdent(i)
        print "--------------------------------------------------"
        return True

    print "--------------------------------------------------"
    return False

def boot():
    """Send DM_CMD_BOOT (image number 0); return True on success."""
    request = DMReqPacket((DM_CMD_BOOT, 0))
    return handleResponse(am.write(request, DM_AMID),
                          "ERROR: Unable to boot the mote")

def reprogram(imgNum):
    """Send DM_CMD_REPROGRAM for image `imgNum`; return True on success."""
    request = DMReqPacket((DM_CMD_REPROGRAM, imgNum))
    return handleResponse(am.write(request, DM_AMID),
                          "ERROR: Unable to reprogram the mote")

def disseminate(imgNum):
    """Send DM_CMD_ONLY_DISSEMINATE for image `imgNum`; return True on success."""
    request = DMReqPacket((DM_CMD_ONLY_DISSEMINATE, imgNum))
    return handleResponse(am.write(request, DM_AMID),
                          "ERROR: Unable to disseminate")

def disseminateAndReboot(imgNum):
    """Send DM_CMD_DISSEMINATE_AND_REPROGRAM for image `imgNum`; return True on success."""
    request = DMReqPacket((DM_CMD_DISSEMINATE_AND_REPROGRAM, imgNum))
    return handleResponse(am.write(request, DM_AMID),
                          "ERROR: Unable to disseminate-and-reboot")

def stop():
    """Send DM_CMD_STOP (image number 0); return True on success."""
    request = DMReqPacket((DM_CMD_STOP, 0))
    return handleResponse(am.write(request, DM_AMID),
                          "ERROR: Unable to initiate the stop")

def localstop():
    """Send DM_CMD_LOCAL_STOP (image number 0); return True on success."""
    request = DMReqPacket((DM_CMD_LOCAL_STOP, 0))
    return handleResponse(am.write(request, DM_AMID),
                          "ERROR: Unable to initiate the local stop")

def print_usage():
    """Print the command-line help to stdout.

    The synopsis lists every option the command dispatcher accepts, and
    the long form shown for -dr is the spelling the dispatcher recognises.
    Each line is printed with a single parenthesised argument, which the
    Python 2 print statement renders exactly like the bare form.
    """
    lines = (
        "Usage: %s <source> <-p|-i|-e|-b|-r|-d|-dr|-s|-ls> [image_number] [options]" % sys.argv[0],
        "  <source> can be:",
        "     serial@PORT:SPEED   Serial ports",
        "     network@HOST:PORT   MIB600",
        "  image_number is required by -p, -i, -e, -r, -d and -dr",
        "  -p --ping        Provide status of the image in the external flash",
        "  -i --inject      Inject a compiled TinyOS application",
        "                   [options]: tos_image.xml file path",
        "  -e --erase       Erase an image in the external flash",
        "  -b --boot        Force a reboot of the mote",
        "  -r --reprogram   Reprogram the mote",
        "  -d --disseminate Disseminate the image in the external flash to the network",
        "  -dr --disseminate-and-reboot",
        "  -s --stop        Stop the dissemination ",
        "  -ls --local-stop Stop the dissemination only on the local mote",
    )
    for line in lines:
        print(line)

#     print "  -b --reprogram_bs\n     Reprogram only the directly-connected mote"
#     print "  -s --reset\n     Reset the versioning information for a given image"

def checkImgNum():
    """Parse the image number from sys.argv[3] into the global `imgNum`.

    Returns the parsed integer.  When the argument is missing or is not a
    valid integer, an error is printed and the script exits with status -1.
    """
    global imgNum
    # Checks for valid image number format.  Only the two expected failures
    # are handled: a missing argument (IndexError) and a non-integer
    # argument (ValueError).
    try:
        imgNum = int(sys.argv[3])
    except (IndexError, ValueError):
        print("ERROR: Image number is not valid")
        sys.exit(-1)
    return imgNum

# ======== MAIN ======== #
if len(sys.argv) >= 3:

    am = tos.AM()

    try:
        print "Checking if node is a Deluge T2 base station ..."
        ident(timeout=1)
    except tos.Timeout:
        print "ERROR: Timeout. Is the node a Deluge T2 base station?"
        sys.exit(-1)

    if sys.argv[2] in ["-p", "--ping"]:
        checkImgNum()
        print "Pinging node ..."
        ping(imgNum)
    elif sys.argv[2] in ["-i", "--inject"] and len(sys.argv) == 5:
        checkImgNum()
        print "Pinging node ..."
        inject(imgNum, sys.argv[4])
    elif sys.argv[2] in ["-e", "--erase"]:
        checkImgNum()
        if erase(imgNum):
            print "Image number %d erased" % imgNum
    elif sys.argv[2] in ["-b", "--boot"]:
        if boot():
            print "Command sent"
    elif sys.argv[2] in ["-r", "--reprogram"]:
        checkImgNum()
        if reprogram(imgNum):
            print "Command sent"
    elif sys.argv[2] in ["-d", "--disseminate"]:
        checkImgNum()
        if disseminate(imgNum):
            print "Command sent"
    elif sys.argv[2] in ["-dr", "--disseminate-and-reboot"]:
        checkImgNum()
        if disseminateAndReboot(imgNum):
            print "Command sent"
    elif sys.argv[2] in ["-s", "--stop"]:
        if stop():
            print "Command sent"
    elif sys.argv[2] in ["-ls", "--local-stop"]:
        if localstop():
            print "Command sent"
    else:
        print_usage()

else:
    print_usage()

sys.exit()

