|
from tensorboardX import SummaryWriter |
|
import unittest |
|
from tensorboardX.record_writer import S3RecordWriter, make_valid_tf_name |
|
import os |
|
import boto3 |
|
from moto import mock_s3 |
|
|
|
# Placeholder AWS credentials for the test run. ``setdefault`` only fills in
# variables that are not already set, so values exported by the caller win.
# NOTE(review): presumably needed so boto3 can build a client without real
# credentials -- confirm against the CI environment.
for _name, _value in (
    ("AWS_ACCESS_KEY_ID", "foobar_key"),
    ("AWS_SECRET_ACCESS_KEY", "foobar_secret"),
):
    os.environ.setdefault(_name, _value)
|
|
|
|
|
class RecordWriterTest(unittest.TestCase):
    """Tests for ``S3RecordWriter`` and ``make_valid_tf_name``."""

    @mock_s3
    def test_record_writer_s3(self):
        """Split an ``s3://`` URL into bucket/path, then write and flush."""
        # The bucket is created inside the ``mock_s3`` context before the
        # writer is used.
        client = boto3.client('s3', region_name='us-east-1')
        client.create_bucket(Bucket='this')

        writer = S3RecordWriter('s3://this/is/apen')
        bucket, path = writer.bucket_and_path()

        # unittest assertions instead of bare ``assert``: they are not
        # stripped under ``python -O`` and report both values on failure.
        self.assertEqual(bucket, 'this')
        self.assertEqual(path, 'is/apen')

        # Smoke check only: write and flush must not raise.
        writer.write(bytes(42))
        writer.flush()

    def test_make_valid_tf_name(self):
        """Check the sanitised form returned for ``'$ave/&sound'``."""
        newname = make_valid_tf_name('$ave/&sound')
        self.assertEqual(newname, '._ave/_sound')
|
|