ComputationGraph

public class ComputationGraph: Graph

Intro

The implementation of Graph. A ComputationGraph is just a set of symbols and the connections between them. The graph computes its output tensors stage by stage.

A typical usage:


let graph = ComputationGraph()

let a = graph.tensor("A", shape: TensorShape(dataType: .float, shape: [2,3]))
let b = graph.tensor("B", shape: TensorShape(dataType: .float, shape: [2,3]))
let (c, op) = graph.operation("", inputSymbols: [a, b], op: PowOperator())

ComputationGraph vs. Model

ComputationGraph is a low-level abstraction of machine learning models. It has two basic functions:

  • forward. Compute results from input to output.
  • backward. Compute gradients and update differentiable data symbols using the optimizer.

Users need to call forward(:) and backward(:) manually to do training, and ComputationGraph is not aware of the loss function. If you want to use a loss function in a graph, you need to add it into the graph as an operator symbol.
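The manual training loop described above can be sketched as follows. `MockGraph` and its no-argument `forward()`/`backward()` are hypothetical stand-ins for the real ComputationGraph API; the point is that the user drives the loop, and the loss must already live inside the graph as an operator symbol:

```swift
// Hedged sketch: `MockGraph` is a hypothetical stand-in for ComputationGraph,
// used only to illustrate the manual training loop the text describes.
class MockGraph {
    private(set) var forwardCalls = 0
    private(set) var backwardCalls = 0
    func forward() { forwardCalls += 1 }   // compute outputs stage by stage
    func backward() { backwardCalls += 1 } // compute grads, update data symbols
}

let graph = MockGraph()
let maxEpochs = 5
for _ in 0..<maxEpochs {
    graph.forward()   // the graph is not aware of any loss function;
    graph.backward()  // the loss is just another operator symbol inside it
}
print(graph.forwardCalls, graph.backwardCalls) // 5 5
```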

Model is a higher-level abstraction inherited from ComputationGraph. It has all the functions ComputationGraph has and, beyond that:

  • Loss function. Model can set up a loss function to drive the backward training.
  • High-level training functions. Model can automatically repeat forward(:) and backward(:) until reaching user-specified conditions such as a maximum number of epochs, early stopping, etc.
  • High level prediction functions.
  • Prepare before forwarding. Usually this should be called just before the 1st forwarding. However, if the user changes the graph structure, it should be called again.

    Declaration

    Swift

    public func forwardPrepare()
  • Use topological sorting to sort all symbols into different stages. Store the results in symbolStages. Currently, only directed acyclic graphs are supported (i.e. no directed cycles).

    Note

    If there are no input operator symbols, a fatalError will be raised.

    Algorithm

    Use DFS(Depth-First Search) to construct the staged information.

    1. Find the 1st-stage operator symbols by looking for those whose inbound symbols have no inbounds themselves.
    2. For each op symbol in the 1st stage, do a visit: I. Mark this op symbol with depth 0, and mark all of its inbound and outbound symbols with depth 0 as well; II. Run DFS starting from this op symbol, following the outBounds path. Each time the search goes one step deeper and meets an op symbol, mark that op symbol (and, as in step I, its inbound and outbound symbols) with the depth increased by 1; III. If the visited symbol is already in the stack, set its depth to max(the current visiting depth, the existing depth).
    3. Then, according to this depth information, add symbols in the same stage into opSymbolsStages.

    Declaration

    Swift

    public func sortGraph()
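The staging algorithm above can be sketched on a toy adjacency list. This is not Serrano's implementation; integers stand in for operator symbols, and `outBounds` is a hypothetical edge map, but the max-depth rule of step III is the same:

```swift
// Hedged sketch of the staging algorithm sortGraph() describes, using a toy
// adjacency list instead of Serrano's symbol types. Each op's stage is the
// maximum depth over all DFS paths reaching it (step III's max() rule).
func stageOps(outBounds: [Int: [Int]], firstStage: [Int]) -> [[Int]] {
    var depth = [Int: Int]()
    func visit(_ op: Int, _ d: Int) {
        // keep only the deepest depth seen so far (the max rule)
        if let existing = depth[op], existing >= d { return }
        depth[op] = d
        for next in outBounds[op] ?? [] { visit(next, d + 1) }
    }
    for op in firstStage { visit(op, 0) }
    // group ops with equal depth into the same stage
    let maxDepth = depth.values.max() ?? 0
    var stages = Array(repeating: [Int](), count: maxDepth + 1)
    for (op, d) in depth { stages[d].append(op) }
    return stages.map { $0.sorted() }
}

// Diamond DAG: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
let stages = stageOps(outBounds: [0: [1, 2], 1: [3], 2: [3]], firstStage: [0])
print(stages) // [[0], [1, 2], [3]]
```

Note how op 3 is reached twice at the same depth 2; the max rule keeps it in the last stage, which is why only DAGs are supported.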
  • Declaration

    Swift

    internal func visitSymbol(_ opSymbol: SerranoOperatorSymbol, opDepthInfo: inout [SerranoOperatorSymbol: Int], currentDepth: Int)

    Parameters

    opSymbol

    The operator symbol currently being visited.

    opDepthInfo

    Dictionary holding the depth recorded so far for each operator symbol.

    currentDepth

    The depth of the current visit.

  • This function verifies:

    • All symbols have been bound to data
    • Shape compatibility between connected operators and tensors

    If any incorrectness is found, it returns false and an associated message.

    Note

    This function assumes the graph has already been sorted and the attribute symbolStages already holds the stage information.

    Declaration

    Swift

    public func verifyGraph() -> (valid: Bool, msg: String)

    Return Value

    valid represents whether the verification passed, and msg carries the error message if any

  • This function checks every path of a sorted graph to see whether all connected operators and symbols match each other in terms of TensorShape.

    Checking

       For each operator symbol in this stage:
           I.   Get input tensors from inBounds and output tensors from outBounds.
             Assign tensors to operator's inputTensors and outputTensors.
           II.  Call inputOutputTensorCheck().
               a). If return true, mark operator's disableInputOutputCheck to true.
                   Continue.
                b). If return false, return false and the related error msg.
    

    Note

    This internal function assumes that all of the graph’s symbols have been bound.

    Declaration

    Swift

    public func checkShapeChain() -> (valid: Bool, msg: String)

    Return Value

    validation result and error message if any
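The shape-chain walk can be illustrated with a toy version. `ToyOp` and its `inShape`/`outShape` fields are hypothetical; Serrano's real check instead assigns bound tensors and delegates to each operator's inputOutputTensorCheck():

```swift
// Hedged sketch of the shape-chain check: walk operators in stage order and
// verify each operator's expected input shape matches the shape produced
// upstream. `ToyOp` is a hypothetical stand-in for an operator symbol.
struct ToyOp {
    let name: String
    let inShape: [Int]   // shape the op expects on its inputs
    let outShape: [Int]  // shape the op produces
}

func checkShapeChain(_ ops: [ToyOp], inputShape: [Int]) -> (valid: Bool, msg: String) {
    var current = inputShape
    for op in ops {
        guard op.inShape == current else {
            // mirror the doc's (valid, msg) return convention
            return (false, "op \(op.name) expects \(op.inShape), got \(current)")
        }
        current = op.outShape
    }
    return (true, "")
}

let ok = checkShapeChain([ToyOp(name: "fc1", inShape: [2, 3], outShape: [2, 4]),
                          ToyOp(name: "fc2", inShape: [2, 4], outShape: [2, 1])],
                         inputShape: [2, 3])
print(ok.valid) // true
```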

  • Check if all user data source symbols have been bound.

    Declaration

    Swift

    public func userInputBindCheck() -> (valid: Bool, msg: String)

    Return Value

    whether the check passed, and the related error message if any.

  • Allocate tensors for all tensor symbols whose bound data are nil.

    Declaration

    Swift

    public func allocateAllTensors()
  • Allocate one big tensor first and assign slices of it to all tensor symbols whose bound data are nil.

    Declaration

    Swift

    public func allocateAllTensorsBigOne()
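The "one big tensor" strategy can be sketched as a single backing buffer carved into offset views. `TensorView` and `allocateBigOne` are illustrative names, not Serrano's real API:

```swift
// Hedged sketch of the one-big-allocation strategy: size a single backing
// buffer for every unbound symbol, then hand each symbol an offset/count
// view into it. One allocation instead of one allocation per tensor.
struct TensorView { let offset: Int; let count: Int }

func allocateBigOne(shapes: [[Int]]) -> (backing: [Float], views: [TensorView]) {
    let counts = shapes.map { $0.reduce(1, *) }  // element count per shape
    var views = [TensorView]()
    var offset = 0
    for c in counts {
        views.append(TensorView(offset: offset, count: c))
        offset += c
    }
    return (Array(repeating: 0, count: offset), views)
}

let (backing, views) = allocateBigOne(shapes: [[2, 3], [4], [2, 2]])
print(backing.count)   // 14  (6 + 4 + 4 elements)
print(views[2].offset) // 10
```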
  • Stage by stage, run all operators.

    Algorithm

    for stage i in [0, n]:
           run all operators in stage i simultaneously
    

    Note

    Graph should have been sorted. Else fatalError will be raised.

    Declaration

    Swift

    internal func stageOrderCalculate(mode: OperatorComputationMode)
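The execution model above can be sketched with GCD: operators within one stage are independent and can run concurrently, while stages themselves run strictly in order. Closures stand in for Serrano operators here; this is an illustration, not the library's implementation:

```swift
import Dispatch
import Foundation

// Thread-safe log so concurrently running "operators" can record themselves.
final class Log {
    private let lock = NSLock()
    private(set) var entries = [String]()
    func add(_ s: String) { lock.lock(); entries.append(s); lock.unlock() }
}

// Run each stage's operators concurrently; concurrentPerform returns only
// when every operator in the stage has finished, so stage order is preserved.
func runStages(_ stages: [[() -> Void]]) {
    for stage in stages {
        DispatchQueue.concurrentPerform(iterations: stage.count) { stage[$0]() }
    }
}

let log = Log()
runStages([
    [{ log.add("A") }, { log.add("B") }],  // stage 0: A and B in parallel
    [{ log.add("C") }]                     // stage 1: runs after stage 0
])
print(log.entries.last!) // C — always last, stages are ordered
```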
  • Preparation work before the backward phase.

    Declaration

    Swift

    internal func backwardPrepare()
  • Allocate tensors and initialize scalar values for grads of data symbol.

    Initialization strategy

    We only look at operators whose attribute enabledParameterUpdate is true, and initialize the currentGrad of those operator symbols’ inbound data symbols whose attribute updatable is true.

    Declaration

    Swift

    internal func allocateTensorsForGrads()
  • This function does the updating during backward training.

    For the operator symbols at each stage, from last to first, calculate the grads for all inbound data symbols. Then update the values of those that are updatable.

    Declaration

    Swift

    internal func dataSymbolsUpdate(_ mode: OperatorComputationMode)
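The per-symbol update step can be illustrated with plain SGD on arrays. This is a hedged sketch only; in Serrano the actual update rule comes from the bound optimizer, not a hard-coded learning rate:

```swift
// Hedged sketch of the update applied to each updatable data symbol,
// reduced to plain SGD: value -= learningRate * grad, elementwise.
func sgdUpdate(values: inout [Float], grads: [Float], learningRate: Float) {
    for i in values.indices {
        values[i] -= learningRate * grads[i]
    }
}

var weights: [Float] = [1.0, 2.0, 3.0]
sgdUpdate(values: &weights, grads: [0.5, 0.5, 0.5], learningRate: 0.1)
print(weights) // each weight decreased by 0.05
```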
  • This function does windup work after updating all grads in this graph.

    Declaration

    Swift

    internal func windup()