diff --git a/lib/Print.lua b/lib/Print.lua new file mode 100644 index 0000000..83a9ffd --- /dev/null +++ b/lib/Print.lua @@ -0,0 +1,15 @@ +local Print, parent = torch.class('w2nn.Print','nn.Module') + +function Print:__init() + parent.__init(self) +end +function Print:updateOutput(input) + print(input:size()) + self.output:resizeAs(input) + self.output:copy(input) + return self.output +end +function Print:updateGradInput(input, gradOutput) + self.gradInput:resizeAs(GradOutput) + return self.gradInput +end diff --git a/lib/PrintTable.lua b/lib/PrintTable.lua new file mode 100644 index 0000000..18aed20 --- /dev/null +++ b/lib/PrintTable.lua @@ -0,0 +1,16 @@ +local PrintTable, parent = torch.class('w2nn.PrintTable','nn.Module') + +function PrintTable:__init(id) + parent.__init(self) + self.id = id +end +function PrintTable:updateOutput(input) + print("----", self.id) + print(input) + self.output = input + return self.output +end +function PrintTable:updateGradInput(input, gradOutput) + self.gradInput = gradOutput + return self.gradInput +end diff --git a/lib/w2nn.lua b/lib/w2nn.lua index 91453fb..bf20ec1 100644 --- a/lib/w2nn.lua +++ b/lib/w2nn.lua @@ -75,5 +75,7 @@ else require 'InplaceClip01' require 'L1Criterion' require 'ShakeShakeTable' + require 'PrintTable' + require 'Print' return w2nn end