import sys, os, importlib, cv2
import numpy as np
import torch

def image_to_input(bgr):
    """Convert an OpenCV image into the array layout fed to the network.

    Args:
        bgr: ``H x W x 3`` uint8 array in BGR channel order, as returned
            by ``cv2.imread``.

    Returns:
        A ``1 x 3 x H x W`` float32 array in RGB channel order with values
        scaled from ``[0, 255]`` to ``[0, 1]``.
    """
    rgb = bgr[:, :, ::-1]  # BGR -> RGB
    return rgb.transpose((2, 0, 1))[None, :].astype("float32") / 255


def output_to_image(sr):
    """Convert a network output array into an image OpenCV can write.

    Args:
        sr: ``1 x 3 x H x W`` float array in RGB channel order, nominally
            in ``[0, 1]``.

    Returns:
        An ``H x W x 3`` uint8 array in BGR channel order.
    """
    hwc = (sr * 255)[0].transpose((1, 2, 0)).round()
    # The network is not constrained to [0, 1]; without clipping, values
    # outside [0, 255] wrap around when cast to uint8 (e.g. 306 -> 50).
    return np.clip(hwc, 0, 255).astype("uint8")[:, :, ::-1]  # RGB -> BGR


def main():
    """Upscale the image given as ``sys.argv[1]`` and write ``x8.png``."""
    if len(sys.argv) < 2:
        sys.exit("usage: {} <input-image>".format(sys.argv[0]))

    # The same factor is passed to the constructor and to the forward call.
    scale = 8

    module = importlib.import_module("model.{}".format("carn"))
    net = module.Net(
        multi_scale=False,
        group=1,
        reduce_upsample=False,
        scale=scale,
    )
    ckpt = "/home/kgrm/data/eccv_modeli/CARN/carn/carn_1297000.pth"
    # map_location lets a checkpoint saved from GPU load on a CPU-only host;
    # the model is never moved off the CPU in this script.
    state_dict = torch.load(ckpt, map_location="cpu")
    net.load_state_dict(state_dict)
    net.eval()  # inference only

    bgr = cv2.imread(sys.argv[1])
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        sys.exit("could not read image: {}".format(sys.argv[1]))
    lr = image_to_input(bgr)

    # No gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        sr = net(torch.tensor(lr), scale).numpy()

    if not cv2.imwrite("x8.png", output_to_image(sr)):
        sys.exit("could not write x8.png")


if __name__ == "__main__":
    main()
