package com.tibbo.aggregate.common.security;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.AlgorithmParameters;
import java.security.GeneralSecurityException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import java.util.Properties;

import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;

import com.tibbo.aggregate.common.Log;
import com.tibbo.aggregate.common.util.StringUtils;
import org.apache.commons.lang3.RandomStringUtils;

/**
 * Static helpers for AES/CBC encryption of keys, passwords and arbitrary strings.
 * <p>
 * Three kinds of keys are handled:
 * <ul>
 * <li>the <em>internal key</em>, derived with PBKDF2 from a random password/salt pair stored in the {@code .ku} properties file
 * (resolved against the current working directory); it protects other keys ({@link #encryptKey(byte[])}) and passwords
 * ({@link #encryptPw(String)});</li>
 * <li>the <em>legacy internal key</em>, derived from hard-coded values and kept for backward compatibility;</li>
 * <li>the <em>encrypt/decrypt keys</em> installed via {@link #setKeys(byte[], byte[])} and used by {@link #encryptString(String)} and
 * {@link #decryptString(String)}.</li>
 * </ul>
 * Encrypted values are encoded as {@code Base64(iv) + ":" + Base64(ciphertext)}.
 * <p>
 * NOTE(review): CBC without a MAC gives confidentiality only, not integrity; the format is kept unchanged for compatibility with
 * already-stored data. State lives in unsynchronized static fields — presumably callers initialize it before concurrent use; confirm.
 */
public class KeyUtils
{
  public static final String REF_PHRASE = "reference";
  private static final String KEY_FACTORY = "PBKDF2WithHmacSHA512";
  private static final String KEY_SPEC = "AES";
  private static final String CIPHER = "AES/CBC/PKCS5Padding";
  private static final String ENCRYPTION_DELIMITER = ":";
  private static final Path INTERNAL_PROPERTIES = Paths.get(".ku");
  
  // PBKDF2 parameters used for every derived key (internal and legacy); changing them would break existing ciphertexts
  private static final int PBKDF2_ITERATIONS = 65536;
  private static final int DERIVED_KEY_BITS = 256;
  
  // Legacy values are kept for backward compatibility
  private static final String LEGACY_SALT = "s011td";
  private static final char[] LEGACY_INTERNAL_KEY = new char[] { 'i', 'B', 'o', 'h', 'S', 'h', '2', 'o', 'h', 'l', '2', 'i' };
  private static final String INTERNAL_KEY = "internal_key";
  private static final String INTERNAL_SALT = "internal_salt";
  private static final int KEY_LEN = 12;
  private static final int SALT_LEN = 6;
  
  private static SecretKeySpec encryptKey;
  private static SecretKeySpec decryptKey;
  private static SecretKeySpec internalKey;
  private static boolean initialized = false;
  
  /**
   * Creates the internal key file with a random alphanumeric password and salt if it does not exist yet, then calls {@link #init()}.
   * An existing file is never overwritten.
   *
   * @throws IOException if the properties file cannot be written
   */
  public static void generateInternalKey() throws IOException
  {
    if (Files.exists(INTERNAL_PROPERTIES))
    {
      init();
      return;
    }
    
    // The password/salt feed PBKDF2, so they must come from a cryptographically strong RNG;
    // randomAlphanumeric() would use a plain java.util.Random. Same alphabet (ASCII letters and digits) as before.
    SecureRandom random = new SecureRandom();
    String key = RandomStringUtils.random(KEY_LEN, 0, 0, true, true, null, random);
    String salt = RandomStringUtils.random(SALT_LEN, 0, 0, true, true, null, random);
    
    Properties properties = new Properties();
    properties.setProperty(INTERNAL_KEY, key);
    properties.setProperty(INTERNAL_SALT, salt);
    try (OutputStream os = Files.newOutputStream(INTERNAL_PROPERTIES))
    {
      properties.store(os, "FOR INTERNAL USAGE ONLY. DO NOT MODIFY.");
    }
    
    init();
  }
  
  /**
   * Derives the internal key from the internal properties file. Should be called before the class is used.
   * <p>
   * On failure (missing/corrupt file, missing properties, unavailable algorithm) the internal key is left {@code null}, in which case
   * {@link #encryptKey(byte[])}/{@link #decryptKey(String)} fall back to plain Base64 and {@link #encryptPw(String)}/
   * {@link #decryptPw(String)} pass values through unchanged. The failure is logged so this downgrade is visible.
   */
  public static void init()
  {
    try
    {
      internalKey = getSecretKeySpec();
    }
    catch (Exception e)
    {
      Log.SECURITY.warn("Internal key is unavailable, keys and passwords will not be encrypted with it: " + e.getMessage(), e);
      internalKey = null;
    }
    
    initialized = true;
  }
  
  /**
   * Derives the internal key from the password and salt stored in the internal properties file.
   * The temporary password array is wiped after use.
   */
  private static SecretKeySpec getSecretKeySpec() throws GeneralSecurityException
  {
    char[] password = null;
    try
    {
      password = getInternalKeyPassword();
      return getSecretKeySpec(password, getInternalKeySalt());
    }
    catch (IOException e)
    {
      throw new GeneralSecurityException(e);
    }
    finally
    {
      if (password != null)
      {
        Arrays.fill(password, '\0');
      }
    }
  }
  
  /**
   * Derives a 256-bit AES key from a password and salt with PBKDF2 (HmacSHA512, 65536 iterations).
   * The caller's {@code keyPassword} array is not modified; only the spec's internal copy is cleared.
   */
  private static SecretKeySpec getSecretKeySpec(char[] keyPassword, byte[] salt) throws GeneralSecurityException
  {
    SecretKeyFactory keyFactory = SecretKeyFactory.getInstance(KEY_FACTORY);
    PBEKeySpec keySpec = new PBEKeySpec(keyPassword, salt, PBKDF2_ITERATIONS, DERIVED_KEY_BITS);
    try
    {
      SecretKey keyTmp = keyFactory.generateSecret(keySpec);
      return new SecretKeySpec(keyTmp.getEncoded(), KEY_SPEC);
    }
    finally
    {
      keySpec.clearPassword();
    }
  }
  
  /**
   * Returns the UTF-8 bytes of the stored salt.
   *
   * @throws SecurityException if the salt property is absent
   */
  private static byte[] getInternalKeySalt() throws IOException
  {
    String salt = getInternalPropertyValue(INTERNAL_SALT);
    if (salt == null)
    {
      throw new SecurityException("Key substitution detected");
    }
    return salt.getBytes(StandardCharsets.UTF_8);
  }
  
  /**
   * Returns the stored password as a fresh char array.
   *
   * @throws SecurityException if the password property is absent
   */
  private static char[] getInternalKeyPassword() throws IOException
  {
    String key = getInternalPropertyValue(INTERNAL_KEY);
    if (key == null)
    {
      throw new SecurityException("Key substitution detected");
    }
    return key.toCharArray();
  }
  
  /**
   * Reads one property from the internal properties file.
   *
   * @return the property value, or {@code null} if the file or the property does not exist
   */
  private static String getInternalPropertyValue(String propertyName) throws IOException
  {
    Properties properties = new Properties();
    if (Files.exists(INTERNAL_PROPERTIES))
    {
      try (InputStream inputStream = Files.newInputStream(INTERNAL_PROPERTIES))
      {
        properties.load(inputStream);
      }
    }
    return properties.getProperty(propertyName);
  }
  
  /**
   * Encrypts bytes with a fresh random IV.
   *
   * @param string the plaintext bytes
   * @param keySpec the AES key
   * @return {@code Base64(iv) + ":" + Base64(ciphertext)}
   */
  public static String encryptByteArrayToString(byte[] string, SecretKeySpec keySpec) throws GeneralSecurityException
  {
    Cipher pbeCipher = Cipher.getInstance(CIPHER);
    pbeCipher.init(Cipher.ENCRYPT_MODE, keySpec);
    AlgorithmParameters parameters = pbeCipher.getParameters();
    IvParameterSpec ivParameterSpec = parameters.getParameterSpec(IvParameterSpec.class);
    byte[] cryptoText = pbeCipher.doFinal(string);
    byte[] iv = ivParameterSpec.getIV();
    return Base64.getEncoder().encodeToString(iv) + ENCRYPTION_DELIMITER + Base64.getEncoder().encodeToString(cryptoText);
  }
  
  /**
   * Decrypts a value produced by {@link #encryptByteArrayToString(byte[], SecretKeySpec)}.
   *
   * @param string {@code Base64(iv) + ":" + Base64(ciphertext)}; any further {@code :}-separated parts are ignored
   * @param keySpec the AES key
   * @return the plaintext bytes
   * @throws GeneralSecurityException on a malformed value (missing delimiter, bad Base64) or a decryption failure
   */
  public static byte[] decryptStringToByteArray(String string, SecretKeySpec keySpec) throws GeneralSecurityException
  {
    String[] parts = string.split(ENCRYPTION_DELIMITER);
    if (parts.length < 2)
    {
      throw new GeneralSecurityException("Malformed encrypted value: expected <iv>" + ENCRYPTION_DELIMITER + "<ciphertext>");
    }
    
    byte[] iv;
    byte[] cipherText;
    try
    {
      iv = Base64.getDecoder().decode(parts[0]);
      cipherText = Base64.getDecoder().decode(parts[1]);
    }
    catch (IllegalArgumentException ex)
    {
      throw new GeneralSecurityException("Malformed encrypted value: invalid Base64", ex);
    }
    
    Cipher pbeCipher = Cipher.getInstance(CIPHER);
    pbeCipher.init(Cipher.DECRYPT_MODE, keySpec, new IvParameterSpec(iv));
    return pbeCipher.doFinal(cipherText);
  }
  
  /**
   * Encrypts a UTF-8 string with the given key. Returns the input unchanged if the key is {@code null} or the string is empty.
   */
  public static String encryptStringBy(String string, SecretKeySpec key) throws GeneralSecurityException
  {
    if (key == null || StringUtils.isEmpty(string))
    {
      return string;
    }
    
    return encryptByteArrayToString(string.getBytes(StandardCharsets.UTF_8), key);
  }
  
  /**
   * Decrypts to a UTF-8 string with the given key. Returns the input unchanged if the key is {@code null} or the string is empty.
   */
  public static String decryptStringBy(String string, SecretKeySpec key) throws GeneralSecurityException
  {
    if (key == null || StringUtils.isEmpty(string))
    {
      return string;
    }
    
    return new String(decryptStringToByteArray(string, key), StandardCharsets.UTF_8);
  }
  
  /**
   * Encrypts a string with the key installed as the encryption key via {@link #setKeys(byte[], byte[])}.
   *
   * @throws SecurityException wrapping any failure
   */
  public static String encryptString(String string) throws SecurityException
  {
    try
    {
      return encryptStringBy(string, getEncryptKey());
    }
    catch (Exception ex)
    {
      Log.SECURITY.fatal("Re-encryption failed: " + ex.getMessage(), ex);
      throw new SecurityException("Encryption failure", ex);
    }
  }
  
  /**
   * Decrypts a string with the key installed as the decryption key via {@link #setKeys(byte[], byte[])}.
   *
   * @throws SecurityException wrapping any failure
   */
  public static String decryptString(String string) throws SecurityException
  {
    try
    {
      return decryptStringBy(string, getDecryptKey());
    }
    catch (Exception ex)
    {
      Log.SECURITY.fatal("Re-encryption failed: " + ex.getMessage(), ex);
      throw new SecurityException("Decryption failure", ex);
    }
  }
  
  public static SecretKeySpec getEncryptKey()
  {
    return encryptKey;
  }
  
  public static SecretKeySpec getDecryptKey()
  {
    return decryptKey;
  }
  
  /**
   * Derives the legacy internal key from the hard-coded password and salt.
   */
  public static SecretKeySpec getLegacyInternalKey() throws GeneralSecurityException
  {
    return getSecretKeySpec(LEGACY_INTERNAL_KEY, LEGACY_SALT.getBytes(StandardCharsets.UTF_8));
  }
  
  /**
   * Returns the internal key, lazily calling {@link #init()} on first use.
   *
   * @return the internal key, or {@code null} if it could not be derived
   */
  public static SecretKeySpec getInternalKey()
  {
    if (!initialized)
    {
      init();
    }
    return internalKey;
  }
  
  /**
   * Returns the internal key's bytes encrypted via {@link #encryptKey(byte[])}.
   *
   * @throws SecurityException if the internal key is unavailable or encryption fails
   */
  public static String getEncryptedInternalKey() throws SecurityException
  {
    SecretKeySpec key = getInternalKey();
    if (key == null)
    {
      throw new SecurityException("Internal key is not available");
    }
    return encryptKey(key.getEncoded());
  }
  
  /**
   * Installs raw AES key bytes as the decryption ({@code oldEncryptionKey}) and/or encryption ({@code newEncryptionKey}) key.
   * A {@code null} argument leaves the corresponding key unchanged.
   *
   * @throws SecurityException if a key array is empty
   */
  public static void setKeys(byte[] oldEncryptionKey, byte[] newEncryptionKey) throws SecurityException
  {
    try
    {
      if (newEncryptionKey != null)
      {
        encryptKey = new SecretKeySpec(newEncryptionKey, KEY_SPEC);
      }
      if (oldEncryptionKey != null)
      {
        decryptKey = new SecretKeySpec(oldEncryptionKey, KEY_SPEC);
      }
    }
    catch (IllegalArgumentException ex)
    {
      throw new SecurityException("Invalid key", ex);
    }
  }
  
  /** Forgets both the encryption and decryption keys. */
  public static void eraseKeys()
  {
    encryptKey = null;
    decryptKey = null;
  }
  
  /**
   * Generates a random 128-bit AES key.
   *
   * @return the raw key bytes, or {@code null} if AES is unavailable
   */
  public static byte[] getRandomKey()
  {
    try
    {
      KeyGenerator keyGenerator = KeyGenerator.getInstance(KeyUtils.KEY_SPEC);
      keyGenerator.init(128);
      SecretKey secretKey = keyGenerator.generateKey();
      return secretKey.getEncoded();
    }
    catch (NoSuchAlgorithmException ex)
    {
      return null;
    }
  }
  
  /**
   * Encrypts raw key bytes with the internal key, or only Base64-encodes them when no internal key is available.
   *
   * @return the encoded key, or {@code null} for a {@code null} input
   * @throws SecurityException if encryption fails
   */
  public static String encryptKey(byte[] key)
  {
    if (key == null)
    {
      return null;
    }
    
    SecretKeySpec internal = getInternalKey();
    if (internal == null)
    {
      return Base64.getEncoder().encodeToString(key);
    }
    
    try
    {
      return encryptByteArrayToString(key, internal);
    }
    catch (Exception ex)
    {
      throw new SecurityException("Key encryption failure", ex);
    }
  }
  
  /**
   * Reverses {@link #encryptKey(byte[])}: decrypts with the internal key, or only Base64-decodes when no internal key is available.
   *
   * @return the raw key bytes, or {@code null} for a {@code null} input
   * @throws SecurityException if decryption fails
   */
  public static byte[] decryptKey(String key) throws SecurityException
  {
    if (key == null)
    {
      return null;
    }
    
    SecretKeySpec internal = getInternalKey();
    if (internal == null)
    {
      return Base64.getDecoder().decode(key);
    }
    
    try
    {
      return decryptStringToByteArray(key, internal);
    }
    catch (Exception ex)
    {
      // Do not log the key material itself, even in encrypted form
      Log.SECURITY.debug("Key decryption failure", ex);
      throw new SecurityException("Key decryption failure", ex);
    }
  }
  
  /**
   * Encrypts a password with the internal key. The input is returned unchanged when no internal key is available.
   *
   * @throws SecurityException wrapping any failure
   */
  public static String encryptPw(String pw) throws SecurityException
  {
    try
    {
      return encryptStringBy(pw, getInternalKey());
    }
    catch (Exception ex)
    {
      Log.SECURITY.warn("Can't encrypt password", ex);
      throw new SecurityException("Password encryption failure", ex);
    }
  }
  
  /**
   * Decrypts a password with the internal key. The input is returned unchanged when no internal key is available.
   *
   * @throws SecurityException wrapping any failure
   */
  public static String decryptPw(String pw) throws SecurityException
  {
    try
    {
      return decryptStringBy(pw, getInternalKey());
    }
    catch (Exception ex)
    {
      Log.SECURITY.warn("Can't decrypt password", ex);
      throw new SecurityException("Password decryption failure", ex);
    }
  }
}
