#!/usr/bin/python
"""Recovers a disk image from a log file generated by salvage_data.py.
"""
import sys

def coalesce_extents(one, two):
    """Return the single (start, end) extent spanning both `one` and `two`.

    ASSUMES the extents overlap; no check is made here, so any gap between
    two disjoint extents would be swallowed into the result.
    """
    lowest_start = min(one[0], two[0])
    highest_end = max(one[1], two[1])
    return (lowest_start, highest_end)

def extents_overlap(one, two):
    """Tell whether two (start, end) extents touch or overlap.

    Returns True if either endpoint of one extent lies within the other
    (boundaries included, so extents that merely touch count), and False
    otherwise.
    """
    def spans(extent, point):
        # Inclusive on both ends: a shared boundary counts as touching.
        return extent[0] <= point <= extent[1]

    return (spans(one, two[0]) or spans(one, two[1]) or
            spans(two, one[0]) or spans(two, one[1]))

class Extents(object):
    """Tracks the extents that have been covered by the log.

    Extents are kept in ``self.extents`` as (start, end) tuples where ``end``
    is ``start + length``.  Extents that touch or overlap are merged as they
    are added.
    """
    def __init__(self):
        self.extents = []

    def add(self, offset, length):
        """Adds an extent at offset of the given length."""
        self.add_extent((offset, offset + length))

    def add_extent(self, extent):
        """Adds a (start, end) extent, merging it with any stored extents it
        touches or overlaps.  Zero-length extents are ignored.
        """
        start, end = extent
        if start == end:  # nothing to add
            return

        # Absorb overlapping extents one at a time, rescanning after each
        # merge because the grown extent may now reach others.  This is a
        # loop rather than recursion so that a new extent covering a large
        # number of stored extents cannot exhaust the recursion limit.
        merged = True
        while merged:
            merged = False
            for i, current_extent in enumerate(self.extents):
                if extents_overlap(extent, current_extent):
                    del self.extents[i]
                    extent = coalesce_extents(current_extent, extent)
                    merged = True
                    break  # the list changed; restart the scan

        # No (further) coalescing needed
        self.extents.append(extent)

    def bytes_covered(self):
        """Returns the number of bytes covered by the extents."""
        return sum(end - start for start, end in self.extents)

    def byte_range(self):
        """Returns a tuple of the starting and ending offsets covered by these
        extents, or (0, 0) when no extents have been recorded.

        Sorts ``self.extents`` in place as a side effect.
        """
        if not self.extents:
            # Without this guard an empty tracker raised IndexError, which
            # also made str() (and therefore reporting) fail.
            return (0, 0)
        self.extents.sort()
        return (self.extents[0][0], self.extents[-1][1])

    def __str__(self):
        """Returns the sorted extents followed by a coverage summary line."""
        start, end = self.byte_range()
        covered = self.bytes_covered()
        return ' '.join([repr(e) for e in self.extents]) + \
            "\n%s of %s bytes covered (%s remain)" % (covered, end-start,
            end-start-covered)

def report_log(good, bad):
    """Write a summary of the good and bad extents to standard error.

    Both arguments are rendered with ``%s``, i.e. via their str() form.
    """
    summary = "Good extents: %s\nBad extents: %s\n" % (good, bad)
    sys.stderr.write(summary)

def write_image_from_log(log, image):
    """Reads from the log file object, and writes the data to the image file
    object.

    The log is a sequence of records, each introduced by a header line:

    * ``D <offset> <length>`` is followed by <length> bytes of data and one
      extra trailing character (a newline); the data is written to the image
      at <offset>.
    * ``E <offset> [<length>]`` marks bad bytes at <offset>; the length
      defaults to 1 when omitted.  Nothing is written to the image.

    Reading stops at end of file or at the first blank header line.

    Returns a (good_extents, bad_extents) tuple of Extents objects.  Raises
    Exception on a malformed header line or on data shorter than announced.
    A summary of the extents seen so far is reported to stderr whether or
    not parsing succeeds.
    """
    good_extents = Extents()
    bad_extents = Extents()
    try:
        while True:
            meta = log.readline().split()
            if not meta:
                break
            if meta[0] == 'D': # data
                if len(meta) < 3:
                    # A truncated header would otherwise surface as a bare
                    # IndexError with no hint of which line was at fault.
                    raise Exception("Invalid line: %r" % (meta,))
                # int() instead of long(): on Python 2 it promotes to long
                # automatically for large values, and long does not exist on
                # Python 3.
                offset = int(meta[1])
                length = int(meta[2])
                data = log.read(length)
                if len(data) != length:
                    raise Exception("Short line: %s of %s bytes at offset %s" \
                        % (len(data), length, offset))
                log.read(1) # the extra newline

                sys.stderr.write("writing %s bytes at %s\n" % (length, offset))
                image.seek(offset)
                image.write(data)
                good_extents.add(offset, length)
            elif meta[0] == 'E':
                if len(meta) < 2:
                    raise Exception("Invalid line: %r" % (meta,))
                offset = int(meta[1])
                if len(meta) > 2:
                    length = int(meta[2])
                else:
                    length = 1
                sys.stderr.write("skipping %s bad bytes at %s\n" % (length,
                    offset))
                bad_extents.add(offset, length)
            else:
                raise Exception("Invalid line: %r" % (meta,))
    finally:
        # Report on success and on failure alike; any exception in flight
        # continues to propagate after the report.
        report_log(good_extents, bad_extents)

    return good_extents, bad_extents

def write_log_from_image(image, out, good_extents, bad_extents):
    """Emit a compact log describing the same extents as the input log.

    Good and bad extents are interleaved in sorted order.  Each bad extent
    becomes one ``E <offset> <length>`` line.  Each good extent is re-read
    from the image file object and written as one or more
    ``D <offset> <length>`` records, each carrying at most 10 MiB of data.
    Raises Exception if the image yields fewer bytes than an extent needs.
    """
    chunk_limit = 10 * 1024 ** 2

    tagged = []
    for first, last in good_extents.extents:
        tagged.append((first, last, 'D'))
    for first, last in bad_extents.extents:
        tagged.append((first, last, 'E'))

    for first, last, kind in sorted(tagged):
        if kind == 'D':
            # Copy the extent out of the image in bounded pieces.
            position = first
            while position < last:
                image.seek(position)
                size = min(chunk_limit, last - position)
                payload = image.read(size)
                if len(payload) != size:
                    raise Exception("Short read from image file")
                out.write("D %s %s\n%s\n" % (position, size, payload))
                out.flush()
                position += size
        elif kind == 'E':
            out.write("E %s %s\n" % (first, last - first))
            out.flush()
        else:
            raise Exception("INTERNAL ERROR: Invalid state \"%s\"" % kind)


def usage(out):
    """Writes the command-line help message to the file object `out`."""
    message = (
        "Syntax Error: %s [-l] <imagefilename>\n"
        "Requires the file to which to write the image, and reads the "
        "recovery log from stdin.\n"
        "/dev/null may be specified as the image filename to just show "
        "summary information.\n"
    )
    # The template's %s was previously never filled in, so a literal "%s"
    # was printed where the program name belongs.
    out.write(message % (sys.argv[0],))

def main(args):
    """Reads log from standard in, writes an image to the image file.

    `args` is the command line without the program name: either
    ``[filename]`` or ``['-l', filename]``.  With ``-l`` a concise log is
    also written to standard out once the image is complete.  Any other
    argument list prints the usage message and exits with status 1.
    """
    if len(args) < 1:
        usage(sys.stderr)
        sys.exit(1)

    output_log = args[0] == '-l'
    if output_log:
        expected_count = 2
    else:
        expected_count = 1
    if len(args) != expected_count:
        usage(sys.stderr)
        sys.exit(1)
    filename = args[-1]

    # NOTE(review): the image is opened in text mode, as before; on platforms
    # that translate newlines, binary mode would be needed -- confirm.
    with open(filename, 'w') as image:
        good, bad = write_image_from_log(sys.stdin, image)

    # The image must be closed (and so flushed) before it is reopened for
    # reading; otherwise data still sitting in the write buffer is missing
    # from the file and the re-read can come up short.
    if output_log:
        with open(filename, 'r') as written:
            write_log_from_image(written, sys.stdout, good, bad)

# Script entry point: pass the command-line arguments, minus the program
# name, to main().
if __name__ == '__main__':
    main(sys.argv[1:])
