Add w2nn.Print for debug
This commit is contained in:
parent
0fe21eef70
commit
1dc32aaa89
15
lib/Print.lua
Normal file
15
lib/Print.lua
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
local Print, parent = torch.class('w2nn.Print','nn.Module')
|
||||||
|
|
||||||
|
function Print:__init()
|
||||||
|
parent.__init(self)
|
||||||
|
end
|
||||||
|
function Print:updateOutput(input)
|
||||||
|
print(input:size())
|
||||||
|
self.output:resizeAs(input)
|
||||||
|
self.output:copy(input)
|
||||||
|
return self.output
|
||||||
|
end
|
||||||
|
function Print:updateGradInput(input, gradOutput)
|
||||||
|
self.gradInput:resizeAs(GradOutput)
|
||||||
|
return self.gradInput
|
||||||
|
end
|
16
lib/PrintTable.lua
Normal file
16
lib/PrintTable.lua
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
local PrintTable, parent = torch.class('w2nn.PrintTable','nn.Module')
|
||||||
|
|
||||||
|
function PrintTable:__init(id)
|
||||||
|
parent.__init(self)
|
||||||
|
self.id = id
|
||||||
|
end
|
||||||
|
function PrintTable:updateOutput(input)
|
||||||
|
print("----", self.id)
|
||||||
|
print(input)
|
||||||
|
self.output = input
|
||||||
|
return self.output
|
||||||
|
end
|
||||||
|
function PrintTable:updateGradInput(input, gradOutput)
|
||||||
|
self.gradInput = gradOutput
|
||||||
|
return self.gradInput
|
||||||
|
end
|
|
@ -75,5 +75,7 @@ else
|
||||||
require 'InplaceClip01'
|
require 'InplaceClip01'
|
||||||
require 'L1Criterion'
|
require 'L1Criterion'
|
||||||
require 'ShakeShakeTable'
|
require 'ShakeShakeTable'
|
||||||
|
require 'PrintTable'
|
||||||
|
require 'Print'
|
||||||
return w2nn
|
return w2nn
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue