# Fixes a "NameError: name 'sharded_tensor' is not defined" error # for the test_named_params_with_sharded_tensor test # See https://github.com/pytorch/pytorch/pull/73309 From 012d490ed76d8af8538d310a508b0e09a91b7632 Mon Sep 17 00:00:00 2001 From: wanchaol Date: Wed, 23 Feb 2022 12:10:39 -0800 Subject: [PATCH] [shard] fix some imports in tests This fixes some imports in sharded optimizer tests Differential Revision: [D34427252](https://our.internmc.facebook.com/intern/diff/D34427252/) [ghstack-poisoned] --- .../_shard/sharded_optim/test_sharded_optim.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/distributed/_shard/sharded_optim/test_sharded_optim.py b/test/distributed/_shard/sharded_optim/test_sharded_optim.py index 085c928985eb..d3f1468aea3c 100644 --- a/test/distributed/_shard/sharded_optim/test_sharded_optim.py +++ b/test/distributed/_shard/sharded_optim/test_sharded_optim.py @@ -2,7 +2,10 @@ import torch import torch.optim as optim -import torch.distributed._shard.sharded_tensor +from torch.distributed._shard import ( + sharded_tensor, + shard_parameter +) from copy import deepcopy from torch.distributed._shard.sharding_spec import ( @@ -77,8 +80,8 @@ def shard_parameter(self): ], ) - sharded_tensor.shard_parameter(self.linear1, "weight", rowwise_sharding_spec) - sharded_tensor.shard_parameter(self.linear2, "weight", colwise_sharding_spec) + shard_parameter(self.linear1, "weight", rowwise_sharding_spec) + shard_parameter(self.linear2, "weight", colwise_sharding_spec) def forward(self, inp): return self.linear2(self.gelu(self.linear1(inp)))