// Copyright (c) 2022, the Dart project authors.  Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.

import 'dart:typed_data';

import 'instructions.dart';
import 'serialize.dart';
import 'types.dart';

/// A Wasm module.
///
/// Serves as a builder for building new modules: definitions are added with
/// the `add*` methods, imports with the `import*` methods and exports with the
/// `export*` methods, after which [encode] produces the binary module.
class Module with SerializerMixin {
  /// Number of bytes in one Wasm memory page.
  ///
  /// Memory limits ([Memory.minSize], [Memory.maxSize]) are serialized as Wasm
  /// limits and are therefore counted in pages, while data segment offsets and
  /// lengths are counted in bytes.
  static const int _wasmPageSize = 0x10000;

  /// Byte offsets to watch during serialization; see the constructor.
  final List<int>? watchPoints;

  /// Canonicalization table used by [addFunctionType].
  final Map<_FunctionTypeKey, FunctionType> functionTypeMap = {};

  final List<DefType> defTypes = [];
  final List<BaseFunction> functions = [];
  final List<Table> tables = [];
  final List<Memory> memories = [];
  final List<Tag> tags = [];
  final List<DataSegment> dataSegments = [];
  final List<Global> globals = [];
  final List<Export> exports = [];

  /// The function referenced by the start section, if any.
  BaseFunction? startFunction;

  // Once a function/memory/global has been defined, no further import of that
  // kind is accepted: indices are assigned from the list length, so imports
  // must occupy the start of each index space.
  bool anyFunctionsDefined = false;
  bool anyMemoriesDefined = false;
  bool anyGlobalsDefined = false;

  /// When set, [encode] emits the data count section before the global
  /// section instead of just before the code section.
  bool dataReferencedFromGlobalInitializer = false;

  /// Number of functions that carry a name (entries in the name section).
  int functionNameCount = 0;

  /// Create a new, initially empty, module.
  ///
  /// The [watchPoints] is a list of byte offsets within the final module of
  /// bytes to watch. When the module is serialized, the stack traces leading to
  /// the production of all watched bytes are printed. This can be used to debug
  /// runtime errors happening at specific offsets within the module.
  Module({this.watchPoints}) {
    if (watchPoints != null) {
      SerializerMixin.traceEnabled = true;
    }
  }

  /// All module imports (functions, memories and globals), in the order they
  /// are written to the import section.
  Iterable<Import> get imports => functions
      .whereType<Import>()
      .followedBy(memories.whereType<Import>())
      .followedBy(globals.whereType<Import>());

  /// All functions defined in the module.
  Iterable<DefinedFunction> get definedFunctions =>
      functions.whereType<DefinedFunction>();

  /// All memories defined in the module.
  Iterable<DefinedMemory> get definedMemories =>
      memories.whereType<DefinedMemory>();

  /// All globals defined in the module.
  Iterable<DefinedGlobal> get definedGlobals =>
      globals.whereType<DefinedGlobal>();

  /// Add a new function type to the module.
  ///
  /// All function types are canonicalized, such that identical types become
  /// the same type definition in the module, assuming nominal type identity
  /// of all inputs and outputs.
  ///
  /// Note that [superType] is not part of the canonicalization key: if a type
  /// with the same inputs and outputs already exists, it is returned as is,
  /// whatever its super type.
  ///
  /// Inputs and outputs can't be changed after the function type is created.
  /// This means that recursive function types (without any non-function types
  /// on the recursion path) are not supported.
  FunctionType addFunctionType(
      Iterable<ValueType> inputs, Iterable<ValueType> outputs,
      {HeapType? superType}) {
    final List<ValueType> inputList = List.unmodifiable(inputs);
    final List<ValueType> outputList = List.unmodifiable(outputs);
    final _FunctionTypeKey key = _FunctionTypeKey(inputList, outputList);
    return functionTypeMap.putIfAbsent(key, () {
      final type = FunctionType(inputList, outputList, superType: superType)
        ..index = defTypes.length;
      defTypes.add(type);
      return type;
    });
  }

  /// Add a new struct type to the module.
  ///
  /// Fields can be added later, by adding to the [fields] list. This enables
  /// struct types to be recursive.
  StructType addStructType(String name,
      {Iterable<FieldType>? fields, HeapType? superType}) {
    final type = StructType(name, fields: fields, superType: superType)
      ..index = defTypes.length;
    defTypes.add(type);
    return type;
  }

  /// Add a new array type to the module.
  ///
  /// The element type can be specified later. This enables array types to be
  /// recursive.
  ArrayType addArrayType(String name,
      {FieldType? elementType, HeapType? superType}) {
    final type = ArrayType(name, elementType: elementType, superType: superType)
      ..index = defTypes.length;
    defTypes.add(type);
    return type;
  }

  /// Add a new function to the module with the given function type.
  ///
  /// The [DefinedFunction.body] must be completed (including the terminating
  /// `end`) before the module can be serialized.
  DefinedFunction addFunction(FunctionType type, [String? name]) {
    anyFunctionsDefined = true;
    if (name != null) functionNameCount++;
    final function = DefinedFunction(this, functions.length, type, name);
    functions.add(function);
    return function;
  }

  /// Add a new table to the module.
  Table addTable(int minSize, [int? maxSize]) {
    final table = Table(tables.length, minSize, maxSize);
    tables.add(table);
    return table;
  }

  /// Add a new memory to the module.
  ///
  /// [minSize] and [maxSize] are written as the memory's Wasm limits.
  DefinedMemory addMemory(int minSize, [int? maxSize]) {
    anyMemoriesDefined = true;
    final memory = DefinedMemory(memories.length, minSize, maxSize);
    memories.add(memory);
    return memory;
  }

  /// Add a new tag to the module.
  Tag addTag(FunctionType type) {
    final tag = Tag(tags.length, type);
    tags.add(tag);
    return tag;
  }

  /// Add a new data segment to the module.
  ///
  /// Either [memory] and [offset] must be both specified or both omitted. If
  /// they are specified, the segment becomes an *active* segment, otherwise it
  /// becomes a *passive* segment.
  ///
  /// If [initialContent] is specified, it defines the initial content of the
  /// segment. The content can be extended later.
  DataSegment addDataSegment(
      [Uint8List? initialContent, Memory? memory, int? offset]) {
    initialContent ??= Uint8List(0);
    assert((memory != null) == (offset != null));
    // [offset] and the content length are in bytes, whereas the memory's
    // minimum size is in pages.
    assert(memory == null ||
        offset! >= 0 &&
            offset + initialContent.length <= memory.minSize * _wasmPageSize);
    final DataSegment data =
        DataSegment(dataSegments.length, initialContent, memory, offset);
    dataSegments.add(data);
    return data;
  }

  /// Add a global variable to the module.
  ///
  /// The [DefinedGlobal.initializer] must be completed (including the
  /// terminating `end`) before the module can be serialized.
  DefinedGlobal addGlobal(GlobalType type) {
    anyGlobalsDefined = true;
    final global = DefinedGlobal(this, globals.length, type);
    globals.add(global);
    return global;
  }

  /// Import a function into the module.
  ///
  /// All imported functions must be specified before any functions are declared
  /// using [Module.addFunction].
  ImportedFunction importFunction(String module, String name, FunctionType type,
      [String? functionName]) {
    if (anyFunctionsDefined) {
      throw "All function imports must be specified before any definitions.";
    }
    if (functionName != null) functionNameCount++;
    final function =
        ImportedFunction(module, name, functions.length, type, functionName);
    functions.add(function);
    return function;
  }

  /// Import a memory into the module.
  ///
  /// All imported memories must be specified before any memories are declared
  /// using [Module.addMemory].
  ImportedMemory importMemory(String module, String name, int minSize,
      [int? maxSize]) {
    if (anyMemoriesDefined) {
      throw "All memory imports must be specified before any definitions.";
    }
    final memory =
        ImportedMemory(module, name, memories.length, minSize, maxSize);
    memories.add(memory);
    return memory;
  }

  /// Import a global variable into the module.
  ///
  /// All imported globals must be specified before any globals are declared
  /// using [Module.addGlobal].
  ImportedGlobal importGlobal(String module, String name, GlobalType type) {
    if (anyGlobalsDefined) {
      throw "All global imports must be specified before any definitions.";
    }
    // Globals have their own index space, separate from functions.
    final global = ImportedGlobal(module, name, globals.length, type);
    globals.add(global);
    return global;
  }

  /// Records [export], asserting that its name is not already taken.
  void _addExport(Export export) {
    assert(!exports.any((e) => e.name == export.name), export.name);
    exports.add(export);
  }

  /// Export a function from the module.
  ///
  /// All exports must have unique names.
  void exportFunction(String name, BaseFunction function) {
    function.exportedName = name;
    _addExport(FunctionExport(name, function));
  }

  /// Export a global variable from the module.
  ///
  /// All exports must have unique names.
  void exportGlobal(String name, Global global) {
    _addExport(GlobalExport(name, global));
  }

  /// Serialize the module to its binary representation.
  ///
  /// The name section is appended unless [emitNameSection] is `false`.
  Uint8List encode({bool emitNameSection = true}) {
    // Wasm module preamble: magic number, version 1.
    writeBytes(const [0x00, 0x61, 0x73, 0x6D, 0x01, 0x00, 0x00, 0x00]);
    TypeSection(this).serialize(this);
    ImportSection(this).serialize(this);
    FunctionSection(this).serialize(this);
    TableSection(this).serialize(this);
    MemorySection(this).serialize(this);
    TagSection(this).serialize(this);
    if (dataReferencedFromGlobalInitializer) {
      DataCountSection(this).serialize(this);
    }
    GlobalSection(this).serialize(this);
    ExportSection(this).serialize(this);
    StartSection(this).serialize(this);
    ElementSection(this).serialize(this);
    if (!dataReferencedFromGlobalInitializer) {
      DataCountSection(this).serialize(this);
    }
    CodeSection(this).serialize(this);
    DataSection(this).serialize(this);
    if (emitNameSection) {
      NameSection(this).serialize(this);
    }
    return data;
  }
}

/// Canonicalization key for function types.
///
/// Two keys are equal when their input lists and their output lists have the
/// same lengths and pairwise-equal elements.
class _FunctionTypeKey {
  final List<ValueType> inputs;
  final List<ValueType> outputs;

  _FunctionTypeKey(this.inputs, this.outputs);

  /// Whether the equally long lists [a] and [b] agree element by element.
  static bool _elementsEqual(List<ValueType> a, List<ValueType> b) {
    var i = 0;
    while (i < a.length && a[i] == b[i]) {
      i++;
    }
    return i == a.length;
  }

  @override
  bool operator ==(Object other) =>
      other is _FunctionTypeKey &&
      inputs.length == other.inputs.length &&
      outputs.length == other.outputs.length &&
      _elementsEqual(inputs, other.inputs) &&
      _elementsEqual(outputs, other.outputs);

  @override
  int get hashCode {
    final int inputHash = inputs.fold<int>(
        13, (int acc, ValueType input) => acc * 17 + input.hashCode);
    final int outputHash = outputs.fold<int>(
        23, (int acc, ValueType output) => acc * 29 + output.hashCode);
    return (2 * inputHash + 1) * (2 * outputHash + 1);
  }
}

/// Common base of imported and defined Wasm functions.
abstract class BaseFunction {
  /// Index of the function in the module's function list.
  final int index;

  /// Signature of the function.
  final FunctionType type;

  /// Optional debug name, written to the name section when present.
  final String? functionName;

  /// The name the function is exported under, or `null` if not exported.
  String? exportedName;

  BaseFunction(this.index, this.type, this.functionName);
}

/// A function defined in the module.
///
/// Its code entry consists of the non-parameter local declarations followed
/// by the instructions in [body].
class DefinedFunction extends BaseFunction
    with SerializerMixin
    implements Serializable {
  /// All local variables defined in the function, including its inputs.
  final List<Local> locals = [];

  /// The body of the function.
  late final Instructions body;

  DefinedFunction(Module module, int index, FunctionType type,
      [String? functionName])
      : super(index, type, functionName) {
    // The parameters occupy the first local slots.
    type.inputs.forEach(addLocal);
    body = Instructions(module, type.outputs, locals: locals);
  }

  /// Add a local variable to the function.
  Local addLocal(ValueType type) {
    locals.add(Local(locals.length, type));
    return locals.last;
  }

  @override
  void serialize(Serializer s) {
    // Group consecutive non-parameter locals of equal type into runs. Each
    // run is written as (count, type) into this object's own buffer, so the
    // total size of the code entry is known before it is emitted.
    final int firstLocal = type.inputs.length;
    final List<int> runCounts = [];
    final List<ValueType> runTypes = [];
    for (int i = firstLocal; i < locals.length; i++) {
      final ValueType current = locals[i].type;
      final bool extendsRun = i > firstLocal && locals[i - 1].type == current;
      if (extendsRun) {
        runCounts[runCounts.length - 1]++;
        // Record the type of the run's last local.
        runTypes[runTypes.length - 1] = current;
      } else {
        runCounts.add(1);
        runTypes.add(current);
      }
    }
    writeUnsigned(runCounts.length);
    for (int run = 0; run < runCounts.length; run++) {
      writeUnsigned(runCounts[run]);
      write(runTypes[run]);
    }

    // Emit the size prefix, then the local declarations and the body.
    assert(body.isComplete);
    s.writeUnsigned(data.length + body.data.length);
    s.writeData(this);
    s.writeData(body);
  }

  @override
  String toString() {
    final String? name = exportedName;
    return name ?? "#$index";
  }
}

/// A local variable of a function; parameters count as locals too.
class Local {
  /// Position of the local among its function's locals.
  final int index;

  /// Value type of the local.
  final ValueType type;

  Local(this.index, this.type);

  /// Renders the local as its bare index.
  @override
  String toString() => index.toString();
}

/// A table of function references in a module.
///
/// The table has [minSize] slots; populated slots are emitted by the element
/// section.
class Table implements Serializable {
  final int index;
  final int minSize;
  final int? maxSize;

  /// Slot contents; `null` marks an empty slot.
  final List<BaseFunction?> elements;

  Table(this.index, this.minSize, this.maxSize)
      : elements = List<BaseFunction?>.filled(minSize, null);

  /// Stores [function] in slot [index]. The slot list is fixed-length, so
  /// [index] must be below [minSize].
  void setElement(int index, BaseFunction function) =>
      elements[index] = function;

  @override
  void serialize(Serializer s) {
    s.writeByte(0x70); // funcref
    final int? max = maxSize;
    // Limits flag: 0x00 = minimum only, 0x01 = minimum and maximum.
    s.writeByte(max == null ? 0x00 : 0x01);
    s.writeUnsigned(minSize);
    if (max != null) s.writeUnsigned(max);
  }
}

/// An (imported or defined) memory in a module.
///
/// [minSize] and [maxSize] are written as the memory's Wasm limits, which the
/// Wasm binary format counts in pages.
class Memory {
  final int index;
  final int minSize;
  final int? maxSize;

  Memory(this.index, this.minSize, [this.maxSize]);

  /// Writes the limits: flag `0x00` and the minimum when there is no maximum,
  /// otherwise flag `0x01`, the minimum and the maximum.
  void _serializeLimits(Serializer s) {
    final int? max = maxSize;
    if (max != null) {
      s
        ..writeByte(0x01)
        ..writeUnsigned(minSize)
        ..writeUnsigned(max);
    } else {
      s
        ..writeByte(0x00)
        ..writeUnsigned(minSize);
    }
  }
}

/// A memory defined (rather than imported) by the module.
class DefinedMemory extends Memory implements Serializable {
  DefinedMemory(int index, int minSize, int? maxSize)
      : super(index, minSize, maxSize);

  /// A memory definition consists of its limits only.
  @override
  void serialize(Serializer s) {
    _serializeLimits(s);
  }
}

/// A tag in a module, with signature [type].
class Tag implements Serializable {
  final int index;
  final FunctionType type;

  Tag(this.index, this.type);

  /// Writes the tag attribute byte followed by the tag's type.
  @override
  void serialize(Serializer s) {
    // 0 byte for exception.
    s.writeByte(0x00);
    s.write(type);
  }

  @override
  String toString() => "#$index";
}

/// A data segment in a module.
///
/// An *active* segment has a [memory] and a byte [offset] into it; a
/// *passive* segment has neither.
class DataSegment implements Serializable {
  /// Number of bytes in one Wasm memory page. [Memory.minSize] is serialized
  /// as Wasm limits and is therefore counted in pages.
  static const int _wasmPageSize = 0x10000;

  final int index;
  final BytesBuilder content;
  final Memory? memory;
  final int? offset;

  DataSegment(this.index, Uint8List initialContent, this.memory, this.offset)
      : content = BytesBuilder()..add(initialContent);

  bool get isActive => memory != null;
  bool get isPassive => memory == null;

  /// Current length of the segment's content, in bytes.
  int get length => content.length;

  /// Append content to the data segment.
  ///
  /// For an active segment, the content must still fit within the initial
  /// size of its memory.
  void append(Uint8List data) {
    content.add(data);
    // Compare byte extents with the memory's minimum size converted to bytes.
    assert(isPassive ||
        offset! >= 0 &&
            offset! + content.length <= memory!.minSize * _wasmPageSize);
  }

  @override
  void serialize(Serializer s) {
    final Memory? target = memory;
    if (target != null) {
      // Active segment: flag 0x00 implies memory 0, flag 0x02 is followed by
      // an explicit memory index.
      if (target.index == 0) {
        s.writeByte(0x00);
      } else {
        s.writeByte(0x02);
        s.writeUnsigned(target.index);
      }
      s.writeByte(0x41); // i32.const
      s.writeSigned(offset!);
      s.writeByte(0x0B); // end
    } else {
      // Passive segment
      s.writeByte(0x01);
    }
    s.writeUnsigned(content.length);
    s.writeBytes(content.toBytes());
  }
}

/// An (imported or defined) global variable in a module.
abstract class Global {
  /// Index of the global in the module's global index space.
  final int index;

  /// The global's type.
  final GlobalType type;

  Global(this.index, this.type);

  /// Renders the global as its bare index.
  @override
  String toString() => index.toString();
}

/// A global variable defined in the module, whose initial value is produced
/// by [initializer].
class DefinedGlobal extends Global implements Serializable {
  /// Initializer expression; must be complete (ending in `end`) before the
  /// module is serialized.
  final Instructions initializer;

  DefinedGlobal(Module module, int index, GlobalType type)
      : initializer =
            Instructions(module, [type.type], isGlobalInitializer: true),
        super(index, type);

  /// Writes the global's type followed by its initializer expression.
  @override
  void serialize(Serializer s) {
    assert(initializer.isComplete);
    s
      ..write(type)
      ..writeData(initializer);
  }
}

/// Any import (function, memory or global).
///
/// The implementations in this library serialize the [module] and [name]
/// strings followed by a kind-specific import descriptor.
abstract class Import implements Serializable {
  /// The name of the module the item is imported from.
  String get module;

  /// The name of the imported item within [module].
  String get name;
}

/// An imported function.
class ImportedFunction extends BaseFunction implements Import {
  @override
  final String module;

  @override
  final String name;

  ImportedFunction(this.module, this.name, int index, FunctionType type,
      [String? functionName])
      : super(index, type, functionName);

  /// Writes the import entry: names, descriptor `0x00` (function) and the
  /// index of the function's type.
  @override
  void serialize(Serializer s) {
    s.writeName(module);
    s.writeName(name);
    s.writeByte(0x00);
    s.writeUnsigned(type.index);
  }

  @override
  String toString() => "$module.$name";
}

/// An imported memory.
class ImportedMemory extends Memory implements Import {
  @override
  final String module;

  @override
  final String name;

  ImportedMemory(this.module, this.name, int index, int minSize, int? maxSize)
      : super(index, minSize, maxSize);

  /// Writes the import entry: names, descriptor `0x02` (memory) and the
  /// memory's limits.
  @override
  void serialize(Serializer s) {
    s.writeName(module);
    s.writeName(name);
    s.writeByte(0x02);
    _serializeLimits(s);
  }
}

/// An imported global variable.
class ImportedGlobal extends Global implements Import {
  @override
  final String module;

  @override
  final String name;

  ImportedGlobal(this.module, this.name, int index, GlobalType type)
      : super(index, type);

  /// Writes the import entry: names, descriptor `0x03` (global) and the
  /// global's type.
  @override
  void serialize(Serializer s) {
    s.writeName(module);
    s.writeName(name);
    s.writeByte(0x03);
    s.write(type);
  }
}

/// An entry of the module's export section.
abstract class Export implements Serializable {
  /// The name under which the item is exported.
  final String name;

  Export(this.name);
}

/// Export of a function under [name].
class FunctionExport extends Export {
  final BaseFunction function;

  FunctionExport(String name, this.function) : super(name);

  /// Writes the name, export kind `0x00` (function) and the function index.
  @override
  void serialize(Serializer s) {
    s
      ..writeName(name)
      ..writeByte(0x00)
      ..writeUnsigned(function.index);
  }
}

/// Export of a global variable under [name].
class GlobalExport extends Export {
  final Global global;

  GlobalExport(String name, this.global) : super(name);

  /// Writes the name, export kind `0x03` (global) and the global index.
  @override
  void serialize(Serializer s) {
    s
      ..writeName(name)
      ..writeByte(0x03)
      ..writeUnsigned(global.index);
  }
}

/// A section of a Wasm module.
///
/// The contents are first written into the section's own buffer by
/// [serializeContents]. The section is then emitted as its [id] byte, the
/// byte length of the contents, and the contents. A section for which
/// [isNotEmpty] is false is omitted entirely.
abstract class Section with SerializerMixin implements Serializable {
  final Module module;

  Section(this.module);

  @override
  void serialize(Serializer s) {
    if (!isNotEmpty) return;
    serializeContents();
    s.writeByte(id);
    s.writeUnsigned(data.length);
    s.writeData(this, module.watchPoints);
  }

  /// The section id byte.
  int get id;

  /// Whether the section has anything to emit.
  bool get isNotEmpty;

  /// Writes the section contents into this section's own buffer.
  void serializeContents();
}

/// The type section (id 1): every type definition, in index order.
class TypeSection extends Section {
  TypeSection(Module module) : super(module);

  @override
  int get id => 1;

  @override
  bool get isNotEmpty => module.defTypes.isNotEmpty;

  @override
  void serializeContents() {
    final List<DefType> definitions = module.defTypes;
    writeUnsigned(definitions.length);
    for (final DefType definition in definitions) {
      definition.serializeDefinition(this);
    }
  }
}

/// The import section (id 2): imported functions, memories and globals.
class ImportSection extends Section {
  ImportSection(Module module) : super(module);

  @override
  int get id => 2;

  @override
  bool get isNotEmpty => module.imports.isNotEmpty;

  @override
  void serializeContents() {
    final List<Import> entries = module.imports.toList();
    writeList(entries);
  }
}

/// The function section (id 3): the type index of every defined function.
class FunctionSection extends Section {
  FunctionSection(Module module) : super(module);

  @override
  int get id => 3;

  @override
  bool get isNotEmpty => module.definedFunctions.isNotEmpty;

  @override
  void serializeContents() {
    final List<DefinedFunction> defined = module.definedFunctions.toList();
    writeUnsigned(defined.length);
    for (final DefinedFunction f in defined) {
      writeUnsigned(f.type.index);
    }
  }
}

/// The table section (id 4): all tables of the module.
class TableSection extends Section {
  TableSection(Module module) : super(module);

  @override
  int get id => 4;

  @override
  bool get isNotEmpty => module.tables.isNotEmpty;

  @override
  void serializeContents() => writeList(module.tables);
}

/// The memory section (id 5): memories defined (not imported) by the module.
class MemorySection extends Section {
  MemorySection(Module module) : super(module);

  @override
  int get id => 5;

  @override
  bool get isNotEmpty => module.definedMemories.isNotEmpty;

  @override
  void serializeContents() {
    final List<DefinedMemory> defined = module.definedMemories.toList();
    writeList(defined);
  }
}

/// The tag section (id 13): all tags of the module.
class TagSection extends Section {
  TagSection(Module module) : super(module);

  @override
  int get id => 13;

  @override
  bool get isNotEmpty => module.tags.isNotEmpty;

  @override
  void serializeContents() => writeList(module.tags);
}

/// The global section (id 6): globals defined (not imported) by the module.
class GlobalSection extends Section {
  GlobalSection(Module module) : super(module);

  @override
  int get id => 6;

  @override
  bool get isNotEmpty => module.definedGlobals.isNotEmpty;

  @override
  void serializeContents() {
    final List<DefinedGlobal> defined = module.definedGlobals.toList();
    writeList(defined);
  }
}

/// The export section (id 7): all exports, in the order they were added.
class ExportSection extends Section {
  ExportSection(Module module) : super(module);

  @override
  int get id => 7;

  @override
  bool get isNotEmpty => module.exports.isNotEmpty;

  @override
  void serializeContents() => writeList(module.exports);
}

/// The start section (id 8): index of the module's start function.
class StartSection extends Section {
  StartSection(Module module) : super(module);

  @override
  int get id => 8;

  @override
  bool get isNotEmpty => module.startFunction != null;

  @override
  void serializeContents() {
    // Only called when isNotEmpty holds, so the start function is set.
    final BaseFunction start = module.startFunction!;
    writeUnsigned(start.index);
  }
}

/// A contiguous run of populated slots in [table], starting at [startIndex],
/// emitted as one active element segment.
class _Element implements Serializable {
  final Table table;
  final int startIndex;
  final List<BaseFunction> entries = [];

  _Element(this.table, this.startIndex);

  @override
  void serialize(Serializer s) {
    // The leading value is a segment flag, not a table index. Flag 0 is an
    // active segment for table 0 with implicit funcref elements. Flag 2 is an
    // active segment with an explicit table index, and it requires an element
    // kind byte after the offset expression.
    final bool explicitTable = table.index != 0;
    if (explicitTable) {
      s.writeByte(0x02);
      s.writeUnsigned(table.index);
    } else {
      s.writeByte(0x00);
    }
    s.writeByte(0x41); // i32.const
    s.writeSigned(startIndex);
    s.writeByte(0x0B); // end
    if (explicitTable) {
      s.writeByte(0x00); // elemkind: funcref
    }
    s.writeUnsigned(entries.length);
    for (final BaseFunction entry in entries) {
      s.writeUnsigned(entry.index);
    }
  }
}

/// The element section (id 9): the populated slots of every table.
class ElementSection extends Section {
  ElementSection(Module module) : super(module);

  @override
  int get id => 9;

  @override
  bool get isNotEmpty => module.tables
      .any((table) => table.elements.any((element) => element != null));

  @override
  void serializeContents() {
    // Split each table's populated slots into maximal contiguous runs; each
    // run becomes one element segment.
    final List<_Element> segments = [];
    for (final Table table in module.tables) {
      _Element? run;
      int slot = 0;
      for (final BaseFunction? function in table.elements) {
        if (function == null) {
          run = null;
        } else {
          if (run == null) {
            run = _Element(table, slot);
            segments.add(run);
          }
          run.entries.add(function);
        }
        slot++;
      }
    }
    writeList(segments);
  }
}

/// The data count section (id 12): the number of data segments.
class DataCountSection extends Section {
  DataCountSection(Module module) : super(module);

  @override
  int get id => 12;

  @override
  bool get isNotEmpty => module.dataSegments.isNotEmpty;

  @override
  void serializeContents() => writeUnsigned(module.dataSegments.length);
}

/// The code section (id 10): locals and body of every defined function.
class CodeSection extends Section {
  CodeSection(Module module) : super(module);

  @override
  int get id => 10;

  @override
  bool get isNotEmpty => module.definedFunctions.isNotEmpty;

  @override
  void serializeContents() {
    final List<DefinedFunction> bodies = module.definedFunctions.toList();
    writeList(bodies);
  }
}

/// The data section (id 11): all data segments, in index order.
class DataSection extends Section {
  DataSection(Module module) : super(module);

  @override
  int get id => 11;

  @override
  bool get isNotEmpty => module.dataSegments.isNotEmpty;

  @override
  void serializeContents() => writeList(module.dataSegments);
}

/// Base class for custom sections, which all use section id 0.
abstract class CustomSection extends Section {
  CustomSection(Module module) : super(module);

  @override
  int get id => 0;
}

/// The `name` custom section, carrying the function names subsection (id 1)
/// with an entry for every function that has a [BaseFunction.functionName].
class NameSection extends CustomSection {
  NameSection(Module module) : super(module);

  @override
  bool get isNotEmpty => module.functionNameCount > 0;

  @override
  void serializeContents() {
    writeName("name");
    // The subsection is buffered separately because its byte length must be
    // written before its contents.
    final names = _NameSubsection()..writeUnsigned(module.functionNameCount);
    final List<BaseFunction> functions = module.functions;
    for (int index = 0; index < functions.length; index++) {
      final String? name = functions[index].functionName;
      if (name == null) continue;
      names
        ..writeUnsigned(index)
        ..writeName(name);
    }
    writeByte(1); // Function names subsection
    writeUnsigned(names.data.length);
    writeData(names);
  }
}

/// Scratch buffer for a subsection of the name section, so its byte length
/// is known before it is copied into the enclosing section.
class _NameSubsection with SerializerMixin {}
